# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

"""IIR and FIR filtering and resampling functions."""

from collections import Counter
from copy import deepcopy
from functools import partial
from math import gcd

import numpy as np
from scipy import fft, signal
from scipy.stats import f as fstat

from ._fiff.pick import _picks_to_idx
from ._ola import _COLA
from .cuda import (
    _fft_multiply_repeated,
    _fft_resample,
    _setup_cuda_fft_multiply_repeated,
    _setup_cuda_fft_resample,
    _smart_pad,
)
from .fixes import minimum_phase
from .parallel import parallel_func
from .utils import (
    _check_option,
    _check_preload,
    _ensure_int,
    _pl,
    _validate_type,
    logger,
    sum_squared,
    verbose,
    warn,
)

# These values from Ifeachor and Jervis.
_length_factors = dict(hann=3.1, hamming=3.3, blackman=5.0)


def next_fast_len(target):
    """Return the smallest 5-smooth number that is at least ``target``.

    5-smooth numbers (also called regular or Hamming numbers) have no prime
    factors other than 2, 3, and 5. SciPy's FFTPACK has efficient functions
    for radix {2, 3, 4, 5}, so such lengths are good choices for
    zero-padding the input of `fft`.

    Parameters
    ----------
    target : int
        Length to start searching from.  Must be a positive integer.

    Returns
    -------
    out : int
        The first 5-smooth number greater than or equal to `target`.

    Notes
    -----
    Copied from SciPy with minor modifications.
    """
    from bisect import bisect_left

    # Every 5-smooth number from 8 through 10000, in ascending order.
    # fmt: off
    hams = (
        8, 9, 10, 12, 15, 16, 18, 20, 24, 25,
        27, 30, 32, 36, 40, 45, 48, 50, 54, 60,
        64, 72, 75, 80, 81, 90, 96, 100, 108, 120,
        125, 128, 135, 144, 150, 160, 162, 180, 192, 200,
        216, 225, 240, 243, 250, 256, 270, 288, 300, 320,
        324, 360, 375, 384, 400, 405, 432, 450, 480, 486,
        500, 512, 540, 576, 600, 625, 640, 648, 675, 720,
        729, 750, 768, 800, 810, 864, 900, 960, 972, 1000,
        1024, 1080, 1125, 1152, 1200, 1215, 1250, 1280, 1296, 1350,
        1440, 1458, 1500, 1536, 1600, 1620, 1728, 1800, 1875, 1920,
        1944, 2000, 2025, 2048, 2160, 2187, 2250, 2304, 2400, 2430,
        2500, 2560, 2592, 2700, 2880, 2916, 3000, 3072, 3125, 3200,
        3240, 3375, 3456, 3600, 3645, 3750, 3840, 3888, 4000, 4050,
        4096, 4320, 4374, 4500, 4608, 4800, 4860, 5000, 5120, 5184,
        5400, 5625, 5760, 5832, 6000, 6075, 6144, 6250, 6400, 6480,
        6561, 6750, 6912, 7200, 7290, 7500, 7680, 7776, 8000, 8100,
        8192, 8640, 8748, 9000, 9216, 9375, 9600, 9720, 10000,
    )
    # fmt: on

    # Tiny sizes are returned unchanged
    if target <= 6:
        return target

    # A power of two has a single set bit, so it is already a fast length
    if target & (target - 1) == 0:
        return target

    # Small sizes are answered from the table, since FFT itself is similarly
    # fast for them.
    if target <= hams[-1]:
        return hams[bisect_left(hams, target)]

    # Otherwise walk over 3**j * 5**k and complete each with a power of two
    best = float("inf")  # anything found will be smaller
    pow5 = 1
    while pow5 < target:
        pow35 = pow5
        while pow35 < target:
            # ceil(target / pow35) using integer arithmetic only
            n_needed = -(-target // pow35)
            # smallest power of two that is >= n_needed
            pow2 = 2 ** int(n_needed - 1).bit_length()

            candidate = pow2 * pow35
            if candidate == target:
                return candidate
            if candidate < best:
                best = candidate
            pow35 *= 3
            if pow35 == target:
                return pow35
        if pow35 < best:
            best = pow35
        pow5 *= 5
        if pow5 == target:
            return pow5
    if pow5 < best:
        best = pow5
    return best


def _overlap_add_filter(
    x,
    h,
    n_fft=None,
    phase="zero",
    picks=None,
    n_jobs=None,
    copy=True,
    pad="reflect_limited",
):
    """Filter the signal x using h with overlap-add FFTs.

    Parameters
    ----------
    x : array
        The data to filter. It is reshaped to 2D internally and each picked
        row is filtered along the last axis.
    h : array, shape (n_taps,)
        The FIR filter coefficients.
    n_fft : int | None
        The FFT block length. If None, a length is chosen automatically from
        a cost heuristic (or a single fast-length block is used when the
        padded signal is shorter than the minimum block length).
    phase : str
        The filter phase. For ``"zero-double"`` the filter is convolved with
        its time-reversed copy before being applied.
    picks : array-like | None
        The rows to filter, interpreted by ``_prep_for_filtering``.
    n_jobs : int | str | None
        Passed to ``_setup_cuda_fft_multiply_repeated`` and
        ``parallel_func``.
    copy : bool
        If True, a copy of ``x`` is filtered; otherwise ``x`` is modified
        in place.
    pad : str
        The padding mode handed to ``_smart_pad``.

    Returns
    -------
    x : array
        The filtered data, with the same shape as the input.

    Raises
    ------
    ValueError
        If ``n_fft`` is shorter than ``2 * len(h) - 1`` (using the length of
        the filter that is actually applied).
    """
    # set up array for filtering, reshape to 2D, operate on last axis
    x, orig_shape, picks = _prep_for_filtering(x, copy, picks)
    _check_zero_phase_length(len(h), phase)
    if len(h) == 1:
        # A single-tap filter is a pure gain: scale only the picked rows and
        # restore the original shape, just like the general path below.
        scale = h[0] ** 2 if phase == "zero-double" else h[0]
        x[picks] *= scale
        x.shape = orig_shape
        return x
    # Extend the signal by mirroring the edges to reduce transient filter
    # response
    n_edge = max(min(len(h), x.shape[1]) - 1, 0)
    logger.debug(f"Smart-padding with:  {n_edge} samples on each edge")
    n_x = x.shape[1] + 2 * n_edge

    if phase == "zero-double":
        h = np.convolve(h, h[::-1])

    # Determine FFT length to use
    min_fft = 2 * len(h) - 1
    if n_fft is None:
        max_fft = n_x
        if max_fft >= min_fft:
            # cost function based on number of multiplications
            N = 2 ** np.arange(
                np.ceil(np.log2(min_fft)), np.ceil(np.log2(max_fft)) + 1, dtype=int
            )
            cost = (
                np.ceil(n_x / (N - len(h) + 1).astype(np.float64))
                * N
                * (np.log2(N) + 1)
            )

            # add a heuristic term to prevent too-long FFT's which are slow
            # (not predicted by mult. cost alone, 4e-5 exp. determined)
            cost += 4e-5 * N * n_x

            n_fft = N[np.argmin(cost)]
        else:
            # Use only a single block
            n_fft = next_fast_len(min_fft)
    logger.debug(f"FFT block length:   {n_fft}")
    if n_fft < min_fft:
        raise ValueError(
            f"n_fft is too short, has to be at least 2 * len(h) - 1 ({min_fft}), got "
            f"{n_fft}"
        )

    # Figure out if we should use CUDA
    n_jobs, cuda_dict = _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft)

    # Process each row separately
    picks = _picks_to_idx(len(x), picks)
    # Use the job count resolved by parallel_func (as _iir_filter does) so
    # that n_jobs=None can take the in-place serial path when it resolves to 1
    parallel, p_fun, n_jobs = parallel_func(_1d_overlap_filter, n_jobs)
    if n_jobs == 1:
        for p in picks:
            x[p] = _1d_overlap_filter(
                x[p], len(h), n_edge, phase, cuda_dict, pad, n_fft
            )
    else:
        data_new = parallel(
            p_fun(x[p], len(h), n_edge, phase, cuda_dict, pad, n_fft) for p in picks
        )
        for pp, p in enumerate(picks):
            x[p] = data_new[pp]

    x.shape = orig_shape
    return x


def _1d_overlap_filter(x, n_h, n_edge, phase, cuda_dict, pad, n_fft):
    """Apply overlap-add FFT FIR filtering to a single 1D signal.

    The signal is padded by ``n_edge`` samples on each side, cut into
    segments of ``n_fft - n_h + 1`` samples, and each zero-padded segment is
    passed through ``_fft_multiply_repeated``. The segment outputs are summed
    into place, the padding is removed, and the result is cast back to the
    dtype of ``x``.
    """
    # pad to reduce ringing
    padded = _smart_pad(x, (n_edge, n_edge), pad)
    n_padded = len(padded)
    out = np.zeros_like(padded)

    hop = n_fft - n_h + 1
    n_blocks = int(np.ceil(n_padded / float(hop)))
    # Output offset: always drop the leading edge padding, and for
    # "zero*" phases also compensate for half the filter length
    delay = n_edge
    if phase.startswith("zero"):
        delay += (n_h - 1) // 2

    # The filtering step itself is identical for zero-phase and single-pass
    for start in range(0, n_blocks * hop, hop):
        block = padded[start : start + hop]
        block = np.concatenate([block, np.zeros(n_fft - len(block))])

        prod = _fft_multiply_repeated(block, cuda_dict)

        # Clip the shifted block output to the bounds of the padded signal
        first = start - delay
        out_lo = max(0, first)
        out_hi = min(first + n_fft, n_padded)
        prod_lo = max(0, -first)
        prod_hi = prod_lo + out_hi - out_lo
        out[out_lo:out_hi] += prod[prod_lo:prod_hi]

    # Drop the padded edges (n_edge can be zero) and cast to the input dtype
    return out[: n_padded - 2 * n_edge].astype(x.dtype)


def _filter_attenuation(h, freq, gain):
    """Compute minimum attenuation at stop frequency."""
    _, filt_resp = signal.freqz(h.ravel(), worN=np.pi * freq)
    filt_resp = np.abs(filt_resp)  # use amplitude response
    filt_resp[np.where(gain == 1)] = 0
    idx = np.argmax(filt_resp)
    att_db = -20 * np.log10(np.maximum(filt_resp[idx], 1e-20))
    att_freq = freq[idx]
    return att_db, att_freq


def _prep_for_filtering(x, copy, picks=None):
    """Flatten the data to 2D and translate picks into row indices.

    Returns the 2D array (all leading dimensions merged into rows), the
    original shape, and the row indices to process. For 3D input the picks
    are repeated for every entry along the first axis.
    """
    x = _check_filterable(x)
    if copy is True:
        x = x.copy()
    orig_shape = x.shape
    x = np.atleast_2d(x)
    picks = _picks_to_idx(x.shape[-2], picks)
    x.shape = (np.prod(x.shape[:-1]), x.shape[-1])
    n_dim = len(orig_shape)
    if n_dim > 3:
        raise ValueError(
            "picks argument is not supported for data with more than three dimensions"
        )
    if n_dim == 3:
        n_epochs, n_channels, _ = orig_shape
        # first flattened row of each epoch, broadcast against the picks
        epoch_starts = np.arange(0, n_channels * n_epochs, n_channels)
        picks = (epoch_starts[:, np.newaxis] + picks).ravel()
    n_rows = x.shape[0]
    assert all(0 <= pick < n_rows for pick in picks)  # guaranteed by above

    return x, orig_shape, picks


def _firwin_design(N, freq, gain, window, sfreq):
    """Construct a FIR filter using firwin."""
    assert freq[0] == 0
    assert len(freq) > 1
    assert len(freq) == len(gain)
    assert N % 2 == 1
    h = np.zeros(N)
    prev_freq = freq[-1]
    prev_gain = gain[-1]
    if gain[-1] == 1:
        h[N // 2] = 1  # start with "all up"
    assert prev_gain in (0, 1)
    for this_freq, this_gain in zip(freq[::-1][1:], gain[::-1][1:]):
        assert this_gain in (0, 1)
        if this_gain != prev_gain:
            # Get the correct N to satisfy the requested transition bandwidth
            transition = (prev_freq - this_freq) / 2.0
            this_N = int(round(_length_factors[window] / transition))
            this_N += 1 - this_N % 2  # make it odd
            if this_N > N:
                raise ValueError(
                    f"The requested filter length {N} is too short for the requested "
                    f"{transition * sfreq / 2.0:0.2f} Hz transition band, which "
                    f"requires {this_N} samples"
                )
            # Construct a lowpass
            this_h = signal.firwin(
                this_N,
                (prev_freq + this_freq) / 2.0,
                window=window,
                pass_zero=True,
                fs=freq[-1] * 2,
            )
            assert this_h.shape == (this_N,)
            offset = (N - this_N) // 2
            if this_gain == 0:
                h[offset : N - offset] -= this_h
            else:
                h[offset : N - offset] += this_h
        prev_gain = this_gain
        prev_freq = this_freq
    return h


def _construct_fir_filter(
    sfreq, freq, gain, filter_length, phase, fir_window, fir_design
):
    """Design FIR coefficients from frequency-domain gain control points.

    ``freq`` must start at 0 and end at the Nyquist frequency
    (``sfreq / 2``); it is normalized by Nyquist before the design. The
    filter is designed with :func:`scipy.signal.firwin2` or with
    ``_firwin_design``, depending on ``fir_design``, and is converted with
    ``minimum_phase`` for the ``"minimum"`` and ``"minimum-half"`` phases.

    A warning is issued when the estimated attenuation is below 20 dB
    (12 dB for ``phase="minimum-half"``).
    """
    assert freq[0] == 0
    if fir_design == "firwin2":
        design_func = signal.firwin2
    else:
        assert fir_design == "firwin"
        design_func = partial(_firwin_design, sfreq=sfreq)

    # normalize frequencies
    nyquist = sfreq / 2.0
    freq = np.array(freq) / nyquist
    if not (freq[0] == 0 and freq[-1] == 1):
        raise ValueError(
            f"freq must start at 0 and end an Nyquist ({nyquist}), got {freq}"
        )
    gain = np.array(gain)

    # Use overlap-add filter with a fixed length
    N = _check_zero_phase_length(filter_length, phase, gain[-1])
    # construct symmetric (linear phase) filter
    if phase == "minimum-half":
        h = minimum_phase(design_func(N * 2 - 1, freq, gain, window=fir_window))
    else:
        h = design_func(N, freq, gain, window=fir_window)
        if phase == "minimum":
            h = minimum_phase(h, half=False)
    assert h.size == N

    att_db, att_freq = _filter_attenuation(h, freq, gain)
    if phase == "zero-double":
        att_db += 6
    # issue a warning if attenuation is less than this
    min_att_db = 12 if phase == "minimum-half" else 20
    if att_db < min_att_db:
        att_freq *= nyquist
        warn(
            f"Attenuation at stop frequency {att_freq:0.2f} Hz is only {att_db:0.2f} "
            "dB. Increase filter_length for higher attenuation."
        )
    return h


def _check_zero_phase_length(N, phase, gain_nyq=0):
    N = int(N)
    if N % 2 == 0:
        if phase == "zero":
            raise RuntimeError(f'filter_length must be odd if phase="zero", got {N}')
        elif phase == "zero-double" and gain_nyq == 1:
            N += 1
    return N


def _check_coefficients(system):
    """Check for filter stability."""
    if isinstance(system, tuple):
        z, p, k = signal.tf2zpk(*system)
    else:  # sos
        z, p, k = signal.sos2zpk(system)
    if np.any(np.abs(p) > 1.0):
        raise RuntimeError(
            "Filter poles outside unit circle, filter will be "
            "unstable. Consider using different filter "
            "coefficients."
        )


def _iir_filter(x, iir_params, picks, n_jobs, copy, phase="zero"):
    """Apply an IIR filter to the picked rows of ``x``.

    For ``phase`` in ``("zero", "zero-double")`` the data are filtered
    forward and backward (``sosfiltfilt`` / ``filtfilt``) with
    ``"reflect_limited"`` padding; otherwise they are filtered once
    (``sosfilt`` / ``lfilter``). Second-order sections are used when
    ``iir_params`` contains ``"sos"``, else the ``"b"`` and ``"a"``
    coefficients are used. The coefficients are checked for stability
    before filtering.
    """
    # set up array for filtering, reshape to 2D, operate on last axis
    x, orig_shape, picks = _prep_for_filtering(x, copy, picks)
    use_sos = "sos" in iir_params
    if phase in ("zero", "zero-double"):
        # never pad by more than the signal can provide
        padlen = min(iir_params["padlen"], x.shape[-1] - 1)
        if use_sos:
            coefs = dict(func=signal.sosfiltfilt, sos=iir_params["sos"])
        else:
            coefs = dict(func=signal.filtfilt, b=iir_params["b"], a=iir_params["a"])
        fun = partial(
            _iir_pad_apply_unpad, padlen=padlen, padtype="reflect_limited", **coefs
        )
    elif use_sos:
        fun = partial(signal.sosfilt, sos=iir_params["sos"], axis=-1)
    else:
        fun = partial(signal.lfilter, b=iir_params["b"], a=iir_params["a"], axis=-1)

    if use_sos:
        _check_coefficients(iir_params["sos"])
    else:
        _check_coefficients((iir_params["b"], iir_params["a"]))

    parallel, p_fun, n_jobs = parallel_func(fun, n_jobs)
    if n_jobs == 1:
        for p in picks:
            x[p] = fun(x=x[p])
    else:
        filtered = parallel(p_fun(x=x[p]) for p in picks)
        for p, row in zip(picks, filtered):
            x[p] = row
    x.shape = orig_shape
    return x


def estimate_ringing_samples(system, max_try=100000):
    """Estimate filter ringing.

    The impulse response is computed in chunks of 1000 samples until a chunk
    contains no sample whose magnitude exceeds 0.1 percent of the largest
    magnitude observed so far.

    Parameters
    ----------
    system : tuple | ndarray
        A tuple of (b, a) or ndarray of second-order sections coefficients.
    max_try : int
        Approximate maximum number of samples to try.
        This will be changed to a multiple of 1000.

    Returns
    -------
    n : int
        The approximate ringing.
    """
    is_tf = isinstance(system, tuple)
    if is_tf:
        b, a = system
        zi = [0.0] * (len(a) - 1)
    else:  # second-order sections
        zi = [[0.0] * 2] * len(system)
    chunk_len = 1000
    n_chunks = int(np.ceil(max_try / float(chunk_len)))
    impulse = np.zeros(chunk_len)
    impulse[0] = 1
    last_above = chunk_len
    threshold = 0
    for chunk_idx in range(n_chunks):
        # carry the filter state over so the chunks form one long response
        if is_tf:
            resp, zi = signal.lfilter(b, a, impulse, zi=zi)
        else:
            resp, zi = signal.sosfilt(system, impulse, zi=zi)
        impulse[0] = 0  # for subsequent iterations we want zero input
        resp = np.abs(resp)
        threshold = max(0.001 * np.max(resp), threshold)
        above = np.where(resp > threshold)[0]
        if len(above) == 0:
            # this chunk had no sufficiently large values, so the ringing
            # ended at the last large sample of the previous chunk
            return (chunk_idx - 1) * chunk_len + last_above
        last_above = above[-1]
    warn("Could not properly estimate ringing for the filter")
    return chunk_len * n_chunks


_ftype_dict = {
    "butter": "Butterworth",
    "cheby1": "Chebyshev I",
    "cheby2": "Chebyshev II",
    "ellip": "Cauer/elliptic",
    "bessel": "Bessel/Thomson",
}


@verbose
def construct_iir_filter(
    iir_params,
    f_pass=None,
    f_stop=None,
    sfreq=None,
    btype=None,
    return_copy=True,
    *,
    phase="zero",
    verbose=None,
):
    """Use IIR parameters to get filtering coefficients.

    This function works like a wrapper for iirdesign and iirfilter in
    scipy.signal to make filter coefficients for IIR filtering. It also
    estimates the number of padding samples based on the filter ringing.
    It creates a new iir_params dict (or updates the one passed to the
    function) with the filter coefficients ('sos', or 'b' and 'a') and an
    estimate of the padding necessary ('padlen') so IIR filtering can be
    performed.

    Parameters
    ----------
    iir_params : dict
        Dictionary of parameters to use for IIR filtering.

            * If ``iir_params['sos']`` exists, it will be used as
              second-order sections to perform IIR filtering.

              .. versionadded:: 0.13

            * Otherwise, if ``iir_params['b']`` and ``iir_params['a']``
              exist, these will be used as coefficients to perform IIR
              filtering.
            * Otherwise, if ``iir_params['order']`` and
              ``iir_params['ftype']`` exist, these will be used with
              `scipy.signal.iirfilter` to make a filter.
              You should also supply ``iir_params['rs']`` and
              ``iir_params['rp']`` if using elliptic or Chebychev filters.
            * Otherwise, if ``iir_params['gpass']`` and
              ``iir_params['gstop']`` exist, these will be used with
              `scipy.signal.iirdesign` to design a filter.
            * ``iir_params['padlen']`` defines the number of samples to pad
              (and an estimate will be calculated if it is not given).
              See Notes for more details.
            * ``iir_params['output']`` defines the system output kind when
              designing filters, either "sos" or "ba". The default is
              "sos".

    f_pass : float or list of float
        Frequency for the pass-band. Low-pass and high-pass filters should
        be a float, band-pass should be a 2-element list of float.
    f_stop : float or list of float
        Stop-band frequency (same size as f_pass). Not used if 'order' is
        specified in iir_params.
    sfreq : float | None
        The sample rate.
    btype : str
        Type of filter. Should be 'lowpass', 'highpass', or 'bandpass'
        (or analogous string representations known to
        :func:`scipy.signal.iirfilter`).
    return_copy : bool
        If False, the 'sos', 'b', 'a', and 'padlen' entries in
        ``iir_params`` will be set inplace (if they weren't already).
        Otherwise, a new ``iir_params`` instance will be created and
        returned with these entries.
    phase : str
        Phase of the filter.
        ``phase='zero'`` (default) or equivalently ``'zero-double'`` constructs and
        applies IIR filter twice, once forward, and once backward (making it non-causal)
        using :func:`~scipy.signal.filtfilt`; ``phase='forward'`` will apply
        the filter once in the forward (causal) direction using
        :func:`~scipy.signal.lfilter`.

        .. versionadded:: 0.13
    %(verbose)s

    Returns
    -------
    iir_params : dict
        Updated iir_params dict, with the entries (set only if they didn't
        exist before) for 'sos' (or 'b', 'a'), and 'padlen' for
        IIR filtering.

    Raises
    ------
    TypeError
        If ``iir_params`` is not a dict.
    RuntimeError
        If a filter must be designed but ``iir_params['ftype']`` is missing
        or unknown, or if the filter has poles outside the unit circle.
    ValueError
        If ``f_pass`` is not 1D, or if neither 'order' nor both 'gpass' and
        'gstop' are given when a filter must be designed.

    See Also
    --------
    mne.filter.filter_data
    mne.io.Raw.filter

    Notes
    -----
    This function triages calls to :func:`scipy.signal.iirfilter` and
    :func:`scipy.signal.iirdesign` based on the input arguments (see
    linked functions for more details).

    .. versionchanged:: 0.14
       Second-order sections are used in filter design by default (replacing
       ``output='ba'`` by ``output='sos'``) to help ensure filter stability
       and reduce numerical error.

    Examples
    --------
    iir_params can have several forms. Consider constructing a low-pass
    filter at 40 Hz with 1000 Hz sampling rate.

    In the most basic (2-parameter) form of iir_params, the order of the
    filter 'order' and the type of filtering 'ftype' are specified. To get
    coefficients for a 4th-order Butterworth filter, this would be:

    >>> iir_params = dict(order=4, ftype='butter', output='sos')  # doctest:+SKIP
    >>> iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low', return_copy=False)  # doctest:+SKIP
    >>> print((2 * len(iir_params['sos']), iir_params['padlen']))  # doctest:+SKIP
    (4, 82)

    Filters can also be constructed using filter design methods. To get a
    40 Hz Chebyshev type 1 lowpass with specific gain characteristics in the
    pass and stop bands (assuming the desired stop band is at 50 Hz), this
    would be a filter with much longer ringing:

    >>> iir_params = dict(ftype='cheby1', gpass=3, gstop=20, output='sos')  # doctest:+SKIP
    >>> iir_params = construct_iir_filter(iir_params, 40, 50, 1000, 'low')  # doctest:+SKIP
    >>> print((2 * len(iir_params['sos']), iir_params['padlen']))  # doctest:+SKIP
    (6, 439)

    Padding and/or filter coefficients can also be manually specified. For
    a 10-sample moving window with no padding during filtering, for example,
    one can just do:

    >>> iir_params = dict(b=np.ones((10)), a=[1, 0], padlen=0)  # doctest:+SKIP
    >>> iir_params = construct_iir_filter(iir_params, return_copy=False)  # doctest:+SKIP
    >>> print((iir_params['b'], iir_params['a'], iir_params['padlen']))  # doctest:+SKIP
    (array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), [1, 0], 0)

    For more information, see the tutorials
    :ref:`disc-filtering` and :ref:`tut-filter-resample`.
    """  # noqa: E501
    known_filters = (
        "bessel",
        "butter",
        "butterworth",
        "cauer",
        "cheby1",
        "cheby2",
        "chebyshev1",
        "chebyshev2",
        "chebyshevi",
        "chebyshevii",
        "ellip",
        "elliptic",
    )
    if not isinstance(iir_params, dict):
        raise TypeError(f"iir_params must be a dict, got {type(iir_params)}")
    # if the filter has been designed, we're good to go
    Wp = None
    if "sos" in iir_params:
        system = iir_params["sos"]
        output = "sos"
    elif "a" in iir_params and "b" in iir_params:
        system = (iir_params["b"], iir_params["a"])
        output = "ba"
    else:
        output = iir_params.get("output", "sos")
        _check_option("output", output, ("ba", "sos"))
        # ensure we have a valid ftype
        if "ftype" not in iir_params:
            raise RuntimeError(
                "ftype must be an entry in iir_params if 'b' and 'a' are not specified."
            )
        ftype = iir_params["ftype"]
        if ftype not in known_filters:
            raise RuntimeError(
                "ftype must be in filter_dict from scipy.signal (e.g., butter, cheby1, "
                f"etc.) not {ftype}"
            )

        # normalize the pass-band edge(s) by the Nyquist frequency
        f_pass = np.atleast_1d(f_pass)
        if f_pass.ndim > 1:
            raise ValueError(f"frequencies must be 1D, got {f_pass.ndim}D")
        edge_freqs = ", ".join(f"{f:0.2f}" for f in f_pass)
        Wp = f_pass / (float(sfreq) / 2)
        # it will be designed
        ftype_nice = _ftype_dict.get(ftype, ftype)
        _validate_type(phase, str, "phase")
        _check_option("phase", phase, ("zero", "zero-double", "forward"))
        if phase in ("zero-double", "zero"):
            ptype = "zero-phase (two-pass forward and reverse) non-causal"
        else:
            ptype = "non-linear phase (one-pass forward) causal"
        logger.info("")
        logger.info("IIR filter parameters")
        logger.info("---------------------")
        logger.info(f"{ftype_nice} {btype} {ptype} filter:")
        # SciPy designs forward for -3dB, so forward-backward is -6dB
        if "order" in iir_params:
            # use order-based design
            singleton = btype in ("low", "lowpass", "high", "highpass")
            use_Wp = Wp.item() if singleton else Wp
            kwargs = dict(
                N=iir_params["order"],
                Wn=use_Wp,
                btype=btype,
                ftype=ftype,
                output=output,
            )
            for key in ("rp", "rs"):
                if key in iir_params:
                    kwargs[key] = iir_params[key]
            system = signal.iirfilter(**kwargs)
            if phase in ("zero", "zero-double"):
                ptype, pmul = "(effective, after forward-backward)", 2
            else:
                ptype, pmul = "(forward)", 1
            logger.info(
                "- Filter order %d %s", pmul * iir_params["order"] * len(Wp), ptype
            )
        else:
            # use gpass / gstop design
            # Validate the required entries before touching f_stop, so that a
            # missing entry gives this error rather than a TypeError from
            # dividing f_stop=None below.
            if "gpass" not in iir_params or "gstop" not in iir_params:
                raise ValueError(
                    "iir_params must have at least 'gstop' and 'gpass' (or N) entries."
                )
            Ws = np.asanyarray(f_stop) / (float(sfreq) / 2)
            system = signal.iirdesign(
                Wp,
                Ws,
                iir_params["gpass"],
                iir_params["gstop"],
                ftype=ftype,
                output=output,
            )

    if system is None:
        raise RuntimeError("coefficients could not be created from iir_params")
    # do some sanity checks
    _check_coefficients(system)

    # get the gains at the cutoff frequencies
    if Wp is not None:
        if output == "sos":
            cutoffs = signal.sosfreqz(system, worN=Wp * np.pi)[1]
        else:
            cutoffs = signal.freqz(system[0], system[1], worN=Wp * np.pi)[1]
        cutoffs = 20 * np.log10(np.abs(cutoffs))
        # 2 * 20 here because we do forward-backward filtering
        if phase in ("zero", "zero-double"):
            cutoffs *= 2
        cutoffs = ", ".join([f"{c:0.2f}" for c in cutoffs])
        logger.info(f"- Cutoff{_pl(f_pass)} at {edge_freqs} Hz: {cutoffs} dB")
    # now deal with padding
    if "padlen" not in iir_params:
        padlen = estimate_ringing_samples(system)
    else:
        padlen = iir_params["padlen"]

    if return_copy:
        iir_params = deepcopy(iir_params)

    iir_params.update(dict(padlen=padlen))
    if output == "sos":
        iir_params.update(sos=system)
    else:
        iir_params.update(b=system[0], a=system[1])
    logger.info("")
    return iir_params


def _check_method(method, iir_params, extra_types=()):
    """Validate ``method`` and fill in default IIR parameters.

    ``"fft"`` is accepted as an alias and returned as ``"fir"``. For
    ``"iir"``, an ``iir_params`` that is None, empty, or contains only
    ``"output"`` is replaced by a 4th-order Butterworth specification. For
    any other method, ``iir_params`` must be None.
    """
    _validate_type(method, "str", "method")
    _check_option("method", method, ["iir", "fir", "fft", *extra_types])
    if method == "fft":
        method = "fir"  # use the better name
    if method != "iir":
        if iir_params is not None:
            raise ValueError('iir_params must be None if method != "iir"')
        return iir_params, method
    if iir_params is None:
        iir_params = dict()
    n_entries = len(iir_params)
    if n_entries == 0 or (n_entries == 1 and "output" in iir_params):
        iir_params = dict(
            order=4, ftype="butter", output=iir_params.get("output", "sos")
        )
    return iir_params, method


@verbose
def filter_data(
    data,
    sfreq,
    l_freq,
    h_freq,
    picks=None,
    filter_length="auto",
    l_trans_bandwidth="auto",
    h_trans_bandwidth="auto",
    n_jobs=None,
    method="fir",
    iir_params=None,
    copy=True,
    phase="zero",
    fir_window="hamming",
    fir_design="firwin",
    pad="reflect_limited",
    *,
    verbose=None,
):
    """Filter a subset of channels.

    Parameters
    ----------
    data : ndarray, shape (..., n_times)
        The data to filter.
    sfreq : float
        The sample frequency in Hz.
    %(l_freq)s
    %(h_freq)s
    %(picks_nostr)s
        Currently this is only supported for 2D (n_channels, n_times) and
        3D (n_epochs, n_channels, n_times) arrays.
    %(filter_length)s
    %(l_trans_bandwidth)s
    %(h_trans_bandwidth)s
    %(n_jobs_fir)s
    %(method_fir)s
    %(iir_params)s
    copy : bool
        If True, a filtered copy of ``data`` is returned. Otherwise, the
        filtering operates on ``data`` in place.
    %(phase)s
    %(fir_window)s
    %(fir_design)s
    %(pad_fir)s
        The default is ``'reflect_limited'``.

        .. versionadded:: 0.15
    %(verbose)s

    Returns
    -------
    data : ndarray, shape (..., n_times)
        The filtered data.

    See Also
    --------
    construct_iir_filter
    create_filter
    mne.io.Raw.filter
    notch_filter
    resample

    Notes
    -----
    The filter is designed with :func:`mne.filter.create_filter` and then
    applied to the channels selected by ``picks``, yielding a low-pass,
    high-pass, band-pass, or band-stop filter.

    ``l_freq`` and ``h_freq`` are the frequencies below which and above
    which, respectively, to filter out of the data. Thus the uses are:

        * ``l_freq < h_freq``: band-pass filter
        * ``l_freq > h_freq``: band-stop filter
        * ``l_freq is not None and h_freq is None``: high-pass filter
        * ``l_freq is None and h_freq is not None``: low-pass filter

    .. note:: If n_jobs > 1, more memory is required as
              ``len(picks) * n_times`` additional time points need to
              be temporarily stored in memory.

    For more information, see the tutorials
    :ref:`disc-filtering` and :ref:`tut-filter-resample` and
    :func:`mne.filter.create_filter`.
    """
    data = _check_filterable(data)
    iir_params, method = _check_method(method, iir_params)
    # Design the filter first, then hand it to the matching backend
    filt = create_filter(
        data,
        sfreq,
        l_freq,
        h_freq,
        filter_length=filter_length,
        l_trans_bandwidth=l_trans_bandwidth,
        h_trans_bandwidth=h_trans_bandwidth,
        method=method,
        iir_params=iir_params,
        phase=phase,
        fir_window=fir_window,
        fir_design=fir_design,
    )
    if method in ("fir", "fft"):
        return _overlap_add_filter(
            data,
            filt,
            n_fft=None,
            phase=phase,
            picks=picks,
            n_jobs=n_jobs,
            copy=copy,
            pad=pad,
        )
    return _iir_filter(
        data, filt, picks=picks, n_jobs=n_jobs, copy=copy, phase=phase
    )


@verbose
def create_filter(
    data,
    sfreq,
    l_freq,
    h_freq,
    filter_length="auto",
    l_trans_bandwidth="auto",
    h_trans_bandwidth="auto",
    method="fir",
    iir_params=None,
    phase="zero",
    fir_window="hamming",
    fir_design="firwin",
    verbose=None,
):
    r"""Create a FIR or IIR filter.

    ``l_freq`` and ``h_freq`` are the frequencies below which and above
    which, respectively, to filter out of the data. Thus the uses are:

        * ``l_freq < h_freq``: band-pass filter
        * ``l_freq > h_freq``: band-stop filter
        * ``l_freq is not None and h_freq is None``: high-pass filter
        * ``l_freq is None and h_freq is not None``: low-pass filter

    Parameters
    ----------
    data : ndarray, shape (..., n_times) | None
        The data that will be filtered. This is used for sanity checking
        only. If None, no sanity checking related to the length of the signal
        relative to the filter order will be performed.
    sfreq : float
        The sample frequency in Hz. Must be strictly positive.
    %(l_freq)s
    %(h_freq)s
    %(filter_length)s
    %(l_trans_bandwidth)s
    %(h_trans_bandwidth)s
    %(method_fir)s
    %(iir_params)s
    %(phase)s
    %(fir_window)s
    %(fir_design)s
    %(verbose)s

    Returns
    -------
    filt : array or dict
        Will be an array of FIR coefficients for method='fir', and dict
        with IIR parameters for method='iir'.

    See Also
    --------
    filter_data

    Notes
    -----
    .. note:: For FIR filters, the *cutoff frequency*, i.e. the -6 dB point,
              is in the middle of the transition band (when using phase='zero'
              and fir_design='firwin'). For IIR filters, the cutoff frequency
              is given by ``l_freq`` or ``h_freq`` directly, and
              ``l_trans_bandwidth`` and ``h_trans_bandwidth`` are ignored.

    **Band-pass filter**

    The frequency response is (approximately) given by::

       1-|               ----------
         |             /|         | \
     |H| |            / |         |  \
         |           /  |         |   \
         |          /   |         |    \
       0-|----------    |         |     --------------
         |         |    |         |     |            |
         0        Fs1  Fp1       Fp2   Fs2          Nyq

    Where:

        * Fs1 = Fp1 - l_trans_bandwidth in Hz
        * Fs2 = Fp2 + h_trans_bandwidth in Hz

    **Band-stop filter**

    The frequency response is (approximately) given by::

        1-|---------                   ----------
          |         \                 /
      |H| |          \               /
          |           \             /
          |            \           /
        0-|             -----------
          |        |    |         |    |        |
          0       Fp1  Fs1       Fs2  Fp2      Nyq

    Where ``Fs1 = Fp1 + l_trans_bandwidth`` and
    ``Fs2 = Fp2 - h_trans_bandwidth``.

    Multiple stop bands can be specified using arrays.

    **Low-pass filter**

    The frequency response is (approximately) given by::

        1-|------------------------
          |                        \
      |H| |                         \
          |                          \
          |                           \
        0-|                            ----------------
          |                       |    |              |
          0                      Fp  Fstop           Nyq

    Where ``Fstop = Fp + trans_bandwidth``.

    **High-pass filter**

    The frequency response is (approximately) given by::

        1-|             -----------------------
          |            /
      |H| |           /
          |          /
          |         /
        0-|---------
          |        |    |                     |
          0      Fstop  Fp                   Nyq

    Where ``Fstop = Fp - trans_bandwidth``.

    .. versionadded:: 0.14
    """
    sfreq = float(sfreq)
    # A sampling frequency of zero is just as meaningless as a negative one
    if sfreq <= 0:
        raise ValueError("sfreq must be positive")
    # If no data specified, sanity checking will be skipped
    if data is None:
        logger.info(
            "No data specified. Sanity checks related to the length of the signal "
            "relative to the filter order will be skipped."
        )
    if h_freq is not None:
        h_freq = np.array(h_freq, float).ravel()
        if (h_freq > (sfreq / 2.0)).any():
            raise ValueError(
                f"h_freq ({h_freq}) must be less than the Nyquist frequency "
                f"{sfreq / 2.0}"
            )
    if l_freq is not None:
        l_freq = np.array(l_freq, float).ravel()
        # A lower edge of 0 Hz is equivalent to having no lower edge at all
        if (l_freq == 0).all():
            l_freq = None
    iir_params, method = _check_method(method, iir_params)
    # Each branch below either builds the IIR filter directly (``out``) or
    # fills in ``freq`` / ``gain`` for the FIR construction at the end
    if l_freq is None and h_freq is None:
        # No band edges: pass-through filter
        (
            data,
            sfreq,
            _,
            _,
            _,
            _,
            filter_length,
            phase,
            fir_window,
            fir_design,
        ) = _triage_filter_params(
            data,
            sfreq,
            None,
            None,
            None,
            None,
            filter_length,
            method,
            phase,
            fir_window,
            fir_design,
        )
        if method == "iir":
            out = dict() if iir_params is None else deepcopy(iir_params)
            out.update(b=np.array([1.0]), a=np.array([1.0]))
        else:
            freq = [0, sfreq / 2.0]
            gain = [1.0, 1.0]
    elif l_freq is None and h_freq is not None:
        h_freq = h_freq.item()
        logger.info(f"Setting up low-pass filter at {h_freq:0.2g} Hz")
        (
            data,
            sfreq,
            _,
            f_p,
            _,
            f_s,
            filter_length,
            phase,
            fir_window,
            fir_design,
        ) = _triage_filter_params(
            data,
            sfreq,
            None,
            h_freq,
            None,
            h_trans_bandwidth,
            filter_length,
            method,
            phase,
            fir_window,
            fir_design,
        )
        if method == "iir":
            out = construct_iir_filter(
                iir_params, f_p, f_s, sfreq, "lowpass", phase=phase
            )
        else:  # 'fir'
            freq = [0, f_p, f_s]
            gain = [1, 1, 0]
            # Extend the stop band up to Nyquist unless it already ends there
            if f_s != sfreq / 2.0:
                freq += [sfreq / 2.0]
                gain += [0]
    elif l_freq is not None and h_freq is None:
        l_freq = l_freq.item()
        logger.info(f"Setting up high-pass filter at {l_freq:0.2g} Hz")
        (
            data,
            sfreq,
            pass_,
            _,
            stop,
            _,
            filter_length,
            phase,
            fir_window,
            fir_design,
        ) = _triage_filter_params(
            data,
            sfreq,
            l_freq,
            None,
            l_trans_bandwidth,
            None,
            filter_length,
            method,
            phase,
            fir_window,
            fir_design,
        )
        if method == "iir":
            out = construct_iir_filter(
                iir_params, pass_, stop, sfreq, "highpass", phase=phase
            )
        else:  # 'fir'
            freq = [stop, pass_, sfreq / 2.0]
            gain = [0, 1, 1]
            # Extend the stop band down to DC unless it already starts there
            if stop != 0:
                freq = [0] + freq
                gain = [0] + gain
    elif l_freq is not None and h_freq is not None:
        if (l_freq < h_freq).any():
            l_freq, h_freq = l_freq.item(), h_freq.item()
            logger.info(
                f"Setting up band-pass filter from {l_freq:0.2g} - {h_freq:0.2g} Hz"
            )
            (
                data,
                sfreq,
                f_p1,
                f_p2,
                f_s1,
                f_s2,
                filter_length,
                phase,
                fir_window,
                fir_design,
            ) = _triage_filter_params(
                data,
                sfreq,
                l_freq,
                h_freq,
                l_trans_bandwidth,
                h_trans_bandwidth,
                filter_length,
                method,
                phase,
                fir_window,
                fir_design,
            )
            if method == "iir":
                out = construct_iir_filter(
                    iir_params,
                    [f_p1, f_p2],
                    [f_s1, f_s2],
                    sfreq,
                    "bandpass",
                    phase=phase,
                )
            else:  # 'fir'
                freq = [f_s1, f_p1, f_p2, f_s2]
                gain = [0, 1, 1, 0]
                if f_s2 != sfreq / 2.0:
                    freq += [sfreq / 2.0]
                    gain += [0]
                if f_s1 != 0:
                    freq = [0] + freq
                    gain = [0] + gain
        else:
            # This could possibly be removed after 0.14 release, but might
            # as well leave it in to sanity check notch_filter
            if len(l_freq) != len(h_freq):
                raise ValueError("l_freq and h_freq must be the same length")
            msg = "Setting up band-stop filter"
            if len(l_freq) == 1:
                l_freq, h_freq = l_freq.item(), h_freq.item()
                msg += f" from {h_freq:0.2g} - {l_freq:0.2g} Hz"
            logger.info(msg)
            # Note: order of outputs is intentionally switched here!
            (
                data,
                sfreq,
                f_s1,
                f_s2,
                f_p1,
                f_p2,
                filter_length,
                phase,
                fir_window,
                fir_design,
            ) = _triage_filter_params(
                data,
                sfreq,
                h_freq,
                l_freq,
                h_trans_bandwidth,
                l_trans_bandwidth,
                filter_length,
                method,
                phase,
                fir_window,
                fir_design,
                bands="arr",
                reverse=True,
            )
            if method == "iir":
                if len(f_p1) != 1:
                    raise ValueError(
                        "Multiple stop-bands can only be used with method='fir' "
                        "and method='spectrum_fit'"
                    )
                out = construct_iir_filter(
                    iir_params,
                    [f_p1[0], f_p2[0]],
                    [f_s1[0], f_s2[0]],
                    sfreq,
                    "bandstop",
                    phase=phase,
                )
            else:  # 'fir'
                freq = np.r_[f_p1, f_s1, f_s2, f_p2]
                gain = np.r_[
                    np.ones_like(f_p1),
                    np.zeros_like(f_s1),
                    np.zeros_like(f_s2),
                    np.ones_like(f_p2),
                ]
                # Sort all band edges (and their gains) by frequency
                order = np.argsort(freq)
                freq = freq[order]
                gain = gain[order]
                # Pad with unit gain at DC and Nyquist when not already there
                if freq[0] != 0:
                    freq = np.r_[[0.0], freq]
                    gain = np.r_[[1.0], gain]
                if freq[-1] != sfreq / 2.0:
                    freq = np.r_[freq, [sfreq / 2.0]]
                    gain = np.r_[gain, [1.0]]
                if np.any(np.abs(np.diff(gain, 2)) > 1):
                    raise ValueError("Stop bands are not sufficiently separated.")
    if method == "fir":
        out = _construct_fir_filter(
            sfreq, freq, gain, filter_length, phase, fir_window, fir_design
        )
    return out


@verbose
def notch_filter(
    x,
    Fs,
    freqs,
    filter_length="auto",
    notch_widths=None,
    trans_bandwidth=1,
    method="fir",
    iir_params=None,
    mt_bandwidth=None,
    p_value=0.05,
    picks=None,
    n_jobs=None,
    copy=True,
    phase="zero",
    fir_window="hamming",
    fir_design="firwin",
    pad="reflect_limited",
    *,
    verbose=None,
):
    r"""Notch filter for the signal x.

    Applies a zero-phase notch filter to the signal x, operating on the last
    dimension.

    Parameters
    ----------
    x : array
        Signal to filter.
    Fs : float
        Sampling rate in Hz.
    freqs : float | array of float | None
        Frequencies to notch filter in Hz, e.g. np.arange(60, 241, 60).
        Multiple stop-bands can only be used with method='fir'
        and method='spectrum_fit'. None can only be used with the method
        'spectrum_fit', where an F test is used to find sinusoidal components.
    %(filter_length_notch)s
    notch_widths : float | array of float | None
        Width of the stop band (centred at each freq in freqs) in Hz.
        If None, freqs / 200 is used. Must be non-negative, and either a
        scalar or the same length as ``freqs``.
    trans_bandwidth : float
        Width of the transition band in Hz.
        Only used for ``method='fir'`` and ``method='iir'``.
    %(method_fir)s
        'spectrum_fit' will use multi-taper estimation of sinusoidal
        components. If freqs=None and method='spectrum_fit', significant
        sinusoidal components are detected using an F test, and noted by
        logging.
    %(iir_params)s
    mt_bandwidth : float | None
        The bandwidth of the multitaper windowing function in Hz.
        Only used in 'spectrum_fit' mode.
    p_value : float
        P-value to use in F-test thresholding to determine significant
        sinusoidal components to remove when method='spectrum_fit' and
        freqs=None. Note that this will be Bonferroni corrected for the
        number of frequencies, so large p-values may be justified.
    %(picks_nostr)s
        Only supported for 2D (n_channels, n_times) and 3D
        (n_epochs, n_channels, n_times) data.
    %(n_jobs_fir)s
    copy : bool
        If True, a copy of x, filtered, is returned. Otherwise, it operates
        on x in place.
    %(phase)s
    %(fir_window)s
    %(fir_design)s
    %(pad_fir)s
        The default is ``'reflect_limited'``.
    %(verbose)s

    Returns
    -------
    xf : array
        The x array filtered.

    See Also
    --------
    filter_data
    resample

    Notes
    -----
    The frequency response is (approximately) given by::

        1-|----------         -----------
          |          \       /
      |H| |           \     /
          |            \   /
          |             \ /
        0-|              -
          |         |    |    |         |
          0        Fp1 freq  Fp2       Nyq

    For each freq in freqs, where ``Fp1 = freq - trans_bandwidth / 2`` and
    ``Fp2 = freq + trans_bandwidth / 2``.

    References
    ----------
    Multi-taper removal is inspired by code from the Chronux toolbox, see
    www.chronux.org and the book "Observed Brain Dynamics" by Partha Mitra
    & Hemant Bokil, Oxford University Press, New York, 2008. Please
    cite this in publications if method 'spectrum_fit' is used.
    """
    x = _check_filterable(x, "notch filtered", "notch_filter")
    iir_params, method = _check_method(method, iir_params, ["spectrum_fit"])

    if freqs is not None:
        freqs = np.atleast_1d(freqs)
    elif method != "spectrum_fit":
        raise ValueError("freqs=None can only be used with method spectrum_fit")

    # Only have to deal with notch_widths for non-autodetect
    if freqs is not None:
        if notch_widths is None:
            notch_widths = freqs / 200.0
        else:
            # Coerce before validating so that plain lists/tuples can be
            # compared elementwise (``[1, 2] < 0`` raises a TypeError)
            notch_widths = np.atleast_1d(notch_widths)
            if np.any(notch_widths < 0):
                raise ValueError("notch_widths must be >= 0")
            if len(notch_widths) == 1:
                # Broadcast a single width to every notch frequency
                notch_widths = notch_widths[0] * np.ones_like(freqs)
            elif len(notch_widths) != len(freqs):
                raise ValueError(
                    "notch_widths must be None, scalar, or the same length as freqs"
                )

    if method in ("fir", "iir"):
        # Speed this up by computing the fourier coefficients once
        tb_2 = trans_bandwidth / 2.0
        lows = [freq - nw / 2.0 - tb_2 for freq, nw in zip(freqs, notch_widths)]
        highs = [freq + nw / 2.0 + tb_2 for freq, nw in zip(freqs, notch_widths)]
        # Passing l_freq=highs and h_freq=lows (l_freq > h_freq) requests a
        # band-stop filter around each notch frequency
        xf = filter_data(
            x,
            Fs,
            highs,
            lows,
            picks,
            filter_length,
            tb_2,
            tb_2,
            n_jobs,
            method,
            iir_params,
            copy,
            phase,
            fir_window,
            fir_design,
            pad=pad,
        )
    elif method == "spectrum_fit":
        xf = _mt_spectrum_proc(
            x,
            Fs,
            freqs,
            notch_widths,
            mt_bandwidth,
            p_value,
            picks,
            n_jobs,
            copy,
            filter_length,
        )

    return xf


def _get_window_thresh(n_times, sfreq, mt_bandwidth, p_value):
    """Compute the multitaper windows and F-statistic threshold for a window.

    Parameters
    ----------
    n_times : int
        Number of samples in the analysis window.
    sfreq : float
        Sampling frequency in Hz.
    mt_bandwidth : float | None
        Multitaper bandwidth passed on to ``_compute_mt_params``.
    p_value : float
        P-value; it is divided by ``n_times`` before being converted to an
        F-statistic threshold.

    Returns
    -------
    window_fun : array
        The tapers returned by ``_compute_mt_params``.
    threshold : float
        The ``1 - p_value / n_times`` point of the F distribution with
        ``(2, 2 * n_tapers - 2)`` degrees of freedom.
    """
    from .time_frequency.multitaper import _compute_mt_params

    # Only the tapers themselves are needed here
    tapers, _, _ = _compute_mt_params(
        n_times, sfreq, mt_bandwidth, False, False, verbose=False
    )

    n_tapers = len(tapers)
    corrected_p = p_value / n_times
    thresh = fstat.ppf(1 - corrected_p, 2, 2 * n_tapers - 2)
    return tapers, thresh


def _mt_spectrum_proc(
    x,
    sfreq,
    line_freqs,
    notch_widths,
    mt_bandwidth,
    p_value,
    picks,
    n_jobs,
    copy,
    filter_length,
):
    """Call _mt_spectrum_remove.

    Prepares ``x`` for filtering, works out the analysis window length,
    removes sinusoidal components from every picked row (optionally in
    parallel), logs how often each (rounded) frequency was found, and
    returns the data in its original shape.
    """
    # set up array for filtering, reshape to 2D, operate on last axis
    x, orig_shape, picks = _prep_for_filtering(x, copy, picks)
    n_times = x.shape[-1]

    # Resolve the window length: "auto" means 10 s, None means the whole
    # signal, and it can never exceed the signal length
    if isinstance(filter_length, str) and filter_length == "auto":
        filter_length = "10s"
    if filter_length is None:
        filter_length = n_times
    filter_length = min(_to_samples(filter_length, sfreq, "", ""), n_times)

    get_wt = partial(
        _get_window_thresh, sfreq=sfreq, mt_bandwidth=mt_bandwidth, p_value=p_value
    )
    window_fun, threshold = get_wt(filter_length)
    parallel, p_fun, n_jobs = parallel_func(_mt_spectrum_remove_win, n_jobs)

    # Arguments shared by every per-row call
    shared = (sfreq, line_freqs, notch_widths, window_fun, threshold, get_wt)
    if n_jobs == 1:
        freq_list = []
        for ii, x_ in enumerate(x):
            if ii not in picks:
                continue
            x[ii], f = _mt_spectrum_remove_win(x_, *shared)
            freq_list.append(f)
    else:
        results = parallel(
            p_fun(x_, *shared) for xi, x_ in enumerate(x) if xi in picks
        )
        freq_list = [res[1] for res in results]
        x[picks, :] = np.array([res[0] for res in results])

    # report found frequencies, but do some sanitizing first by binning into
    # 1 Hz bins
    counts = Counter()
    for row_freqs in freq_list:
        for win_freqs in row_freqs:
            counts.update(np.unique(np.round(win_freqs)).tolist())
    kind = "Detected" if line_freqs is None else "Removed"
    report_lines = [
        f"    {freq:6.2f} : {counts[freq]:4d} window{_pl(counts[freq])}"
        for freq in sorted(counts)
    ]
    found_freqs = "\n".join(report_lines) or "    None"
    logger.info(f"{kind} notch frequencies (Hz):\n{found_freqs}")

    x.shape = orig_shape
    return x


def _mt_spectrum_remove_win(
    x, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh
):
    """Remove sinusoidal components window-by-window via ``_COLA``.

    Returns the cleaned signal (same shape as ``x``) and a list with the
    removed frequencies of each processed chunk.
    """
    win_len = window_fun.shape[1]
    half_overlap = (win_len + 1) // 2
    cleaned = np.zeros_like(x)
    removed = []

    # Define how to process a chunk of data
    def process(x_, *, start, stop):
        chunk_clean, chunk_freqs = _mt_spectrum_remove(
            x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh
        )
        removed.append(chunk_freqs)
        return (chunk_clean,)  # must return a tuple

    cola = _COLA(
        process, cleaned, x.shape[-1], win_len, half_overlap, sfreq, verbose=False
    )
    cola.feed(x)
    return cleaned, removed


def _mt_spectrum_remove(
    x, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh
):
    """Use MT-spectrum to remove line frequencies.

    Based on Chronux. If line_freqs is specified, the frequency bin closest
    to each line_freq, plus all bins within ``notch_width / 2`` of it, are
    fitted with a sinusoid that is subtracted from the data. If line_freqs
    is None, the bins to remove are those whose F-statistic exceeds
    ``threshold``.

    Parameters
    ----------
    x : array, shape (n_times,)
        The 1D signal chunk to clean.
    sfreq : float
        Sampling frequency in Hz.
    line_freqs : array-like | None
        Frequencies to remove, or None to detect them with an F test.
    notch_widths : array-like | None
        Width around each entry of ``line_freqs``; only used when
        ``line_freqs`` is not None.
    window_fun : array, shape (n_tapers, n_samples)
        The tapers.
    threshold : float
        F-statistic threshold used when ``line_freqs`` is None.
    get_thresh : callable
        Called with the chunk length to recompute ``window_fun`` and
        ``threshold`` when the chunk length differs from the taper length.

    Returns
    -------
    x_clean : array, shape (n_times,)
        ``x`` with the fitted sinusoids subtracted.
    rm_freqs : array
        The frequencies (in the units returned by ``_mt_spectra``) that were
        removed.
    """
    from .time_frequency.multitaper import _mt_spectra

    assert x.ndim == 1
    # The tapers must match the chunk length; recompute them if they do not
    if x.shape[-1] != window_fun.shape[-1]:
        window_fun, threshold = get_thresh(x.shape[-1])
    # drop the even tapers
    # ("odd"/"even" is 1-based here: tapers_odd holds indices 0, 2, 4, ...)
    n_tapers = len(window_fun)
    tapers_odd = np.arange(0, n_tapers, 2)
    tapers_even = np.arange(1, n_tapers, 2)
    tapers_use = window_fun[tapers_odd]

    # sum tapers for (used) odd prolates across time (n_tapers, 1)
    H0 = np.sum(tapers_use, axis=1)

    # sum of squares across tapers (1, )
    H0_sq = sum_squared(H0)

    # make "time" vector
    # (sample times in seconds multiplied by 2 * pi)
    rads = 2 * np.pi * (np.arange(x.size) / float(sfreq))

    # compute mt_spectrum (returning n_ch, n_tapers, n_freq)
    x_p, freqs = _mt_spectra(x[np.newaxis, :], window_fun, sfreq)

    # sum of the product of x_p and H0 across tapers (1, n_freqs)
    x_p_H0 = np.sum(x_p[:, tapers_odd, :] * H0[np.newaxis, :, np.newaxis], axis=1)

    # resulting calculated amplitudes for all freqs
    A = x_p_H0 / H0_sq

    if line_freqs is None:
        # figure out which freqs to remove using F stat

        # estimated coefficient
        x_hat = A * H0[:, np.newaxis]

        # numerator for F-statistic
        num = (n_tapers - 1) * (A * A.conj()).real * H0_sq
        # denominator for F-statistic
        den = np.sum(np.abs(x_p[:, tapers_odd, :] - x_hat) ** 2, 1) + np.sum(
            np.abs(x_p[:, tapers_even, :]) ** 2, 1
        )
        # avoid division by zero: an infinite denominator gives f_stat == 0
        den[den == 0] = np.inf
        f_stat = num / den

        # find frequencies to remove
        # (f_stat has a leading singleton axis, so take the second index)
        indices = np.where(f_stat > threshold)[1]
        rm_freqs = freqs[indices]
    else:
        # specify frequencies
        # indices_1: the single closest bin to each requested line frequency
        indices_1 = np.unique([np.argmin(np.abs(freqs - lf)) for lf in line_freqs])
        # indices_2: every bin strictly inside lf +/- nw / 2
        indices_2 = [
            np.logical_and(freqs > lf - nw / 2.0, freqs < lf + nw / 2.0)
            for lf, nw in zip(line_freqs, notch_widths)
        ]
        indices_2 = np.where(np.any(np.array(indices_2), axis=0))[0]
        indices = np.unique(np.r_[indices_1, indices_2])
        rm_freqs = freqs[indices]

    # Build one real sinusoid per selected bin from its complex amplitude
    fits = list()
    for ind in indices:
        c = 2 * A[0, ind]
        fit = np.abs(c) * np.cos(freqs[ind] * rads + np.angle(c))
        fits.append(fit)

    if len(fits) == 0:
        # nothing to remove: subtracting 0.0 leaves the data unchanged
        datafit = 0.0
    else:
        # fitted sinusoids are summed, and subtracted from data
        datafit = np.sum(fits, axis=0)

    return x - datafit, rm_freqs


def _check_filterable(x, kind="filtered", alternative="filter"):
    """Validate that ``x`` is array-like float64 data and return it as array.

    Parameters
    ----------
    x : ndarray | list | tuple
        The data to check.
    kind : str
        Past-tense description of the operation, used in error messages.
    alternative : str
        Name of the instance method suggested when ``x`` is a Raw, Epochs, or
        Evoked instance.

    Returns
    -------
    x : ndarray
        The data passed through :func:`numpy.asanyarray`.
    """
    # Let's be fairly strict about this -- users can easily coerce to ndarray
    # at their end, and we already should do it internally any time we are
    # using these low-level functions. At the same time, let's
    # help people who might accidentally use low-level functions that they
    # shouldn't use by pushing them in the right direction
    from .epochs import BaseEpochs
    from .evoked import Evoked
    from .io import BaseRaw

    if isinstance(x, BaseRaw | BaseEpochs | Evoked):
        # Best effort: if the class name cannot be obtained, fall through to
        # the generic type validation below
        try:
            name = x.__class__.__name__
        except Exception:
            name = None
        if name is not None:
            raise TypeError(
                "This low-level function only operates on np.ndarray instances. To get "
                f"a {kind} {name} instance, use a method like `inst_new = inst.copy()."
                f"{alternative}(...)` instead."
            )
    _validate_type(x, (np.ndarray, list, tuple), f"Data to be {kind}")
    x = np.asanyarray(x)
    if x.dtype == np.float64:
        return x
    raise ValueError(f"Data to be {kind} must be real floating, got {x.dtype}")


def _resamp_ratio_len(up, down, n):
    """Return the resampling ratio and the resampled length (at least 1)."""
    ratio = float(up) / down
    n_out = int(round(ratio * n))
    if n_out < 1:
        n_out = 1  # never produce an empty output
    return ratio, n_out


@verbose
def resample(
    x,
    up=1.0,
    down=1.0,
    *,
    axis=-1,
    window="auto",
    n_jobs=None,
    pad="auto",
    npad=100,
    method="fft",
    verbose=None,
):
    """Resample an array.

    Operates along ``axis`` (by default the last dimension) of the array.

    Parameters
    ----------
    x : ndarray
        Signal to resample.
    up : float
        Factor to upsample by.
    down : float
        Factor to downsample by.
    axis : int
        Axis along which to resample (default is the last axis).
    %(window_resample)s
    %(n_jobs_cuda)s
        ``n_jobs='cuda'`` is only supported when ``method="fft"``.
    %(pad_resample_auto)s

        .. versionadded:: 0.15
    %(npad_resample)s
    %(method_resample)s

        .. versionadded:: 1.7
    %(verbose)s

    Returns
    -------
    y : array
        The x array resampled.

    Notes
    -----
    When using ``method="fft"`` (default),
    this uses (hopefully) intelligent edge padding and frequency-domain
    windowing to improve :func:`scipy.signal.resample`'s resampling method,
    which we have adapted for our use here. Choices of npad and window have
    important consequences, and the default choices should work well
    for most natural signals.
    """
    _validate_type(method, str, "method")
    _validate_type(pad, str, "pad")
    _check_option("method", method, ("fft", "polyphase"))

    # make sure our arithmetic will work
    x = _check_filterable(x, "resampled", "resample")
    ratio, final_len = _resamp_ratio_len(up, down, x.shape[axis])
    del up, down
    if axis < 0:
        axis = x.ndim + axis
    if x.shape[axis] == 0:
        warn(f"x has zero length along axis={axis}, returning a copy of x")
        return x.copy()

    # prep for resampling along the last axis (swap axis with last then reshape)
    # The output shape must be derived from the *swapped* array: swapaxes
    # exchanges two axes rather than moving ``axis`` to the end, so the leading
    # dimensions can be in a different order than in the input.
    x = x.swapaxes(axis, -1)
    out_shape = list(x.shape)
    out_shape[-1] = final_len
    x = np.atleast_2d(x.reshape((-1, x.shape[-1])))

    # do the resampling using FFT or polyphase methods
    kwargs = dict(pad=pad, window=window, n_jobs=n_jobs)
    if method == "fft":
        y = _resample_fft(x, npad=npad, ratio=ratio, final_len=final_len, **kwargs)
    else:
        up, down, kwargs["window"] = _prep_polyphase(
            ratio, x.shape[-1], final_len, window
        )
        # Report the length of the filter actually used, not of the
        # user-supplied ``window`` argument (which may be e.g. "auto")
        half_len = len(kwargs["window"]) // 2
        logger.info(
            f"Polyphase resampling neighborhood: ±{half_len} "
            f"input sample{_pl(half_len)}"
        )
        y = _resample_polyphase(x, up=up, down=down, **kwargs)
    assert y.shape[-1] == final_len

    # restore dimensions (reshape then swap axis with last)
    y = y.reshape(out_shape).swapaxes(axis, -1)

    return y


def _prep_polyphase(ratio, x_len, final_len, window):
    if isinstance(window, str) and window == "auto":
        window = ("kaiser", 5.0)  # SciPy default
    up = final_len
    down = x_len
    g_ = gcd(up, down)
    up = up // g_
    down = down // g_
    # Figure out our signal neighborhood and design window (adapted from SciPy)
    if not isinstance(window, list | np.ndarray):
        # Design a linear-phase low-pass FIR filter
        max_rate = max(up, down)
        f_c = 1.0 / max_rate  # cutoff of FIR filter (rel. to Nyquist)
        half_len = 10 * max_rate  # reasonable cutoff for sinc-like function
        window = signal.firwin(2 * half_len + 1, f_c, window=window)
    return up, down, window


def _resample_polyphase(x, *, up, down, pad, window, n_jobs):
    """Resample rows of ``x`` with :func:`scipy.signal.resample_poly`.

    ``pad="auto"`` is translated to ``"reflect"``. With a single job the whole
    array is processed in one call along the last axis; otherwise each row is
    dispatched separately and the results are stacked.
    """
    padtype = "reflect" if pad == "auto" else pad
    poly_kwargs = dict(padtype=padtype, window=window, up=up, down=down)
    _validate_type(
        n_jobs, (None, "int-like"), "n_jobs", extra="when method='polyphase'"
    )
    parallel, p_fun, n_jobs = parallel_func(signal.resample_poly, n_jobs)
    if n_jobs != 1:
        return np.array(parallel(p_fun(x_, **poly_kwargs) for x_ in x))
    return signal.resample_poly(x, axis=-1, **poly_kwargs)


def _resample_fft(x_flat, *, ratio, final_len, pad, window, npad, n_jobs):
    """Resample the rows of a 2D array using the FFT method.

    Parameters
    ----------
    x_flat : ndarray, shape (n_signals, n_times)
        The signals to resample along the last axis.
    ratio : float
        Resampling ratio (output rate divided by input rate).
    final_len : int
        Desired number of output samples per signal.
    pad : str
        Padding mode; ``"auto"`` is translated to ``"reflect_limited"``.
    window : str | tuple | callable | ndarray | None
        Frequency-domain window. ``"auto"`` and None mean ``"boxcar"``. A
        callable is evaluated on the FFT frequencies, an array of the padded
        length is used directly, and anything else is passed to
        :func:`scipy.signal.get_window`. The passed object is never modified.
    npad : int | str
        Number of samples to pad on each side, or ``"auto"`` to pad to a
        power of two.
    n_jobs : int | str | None
        Passed to the CUDA setup and parallel helpers.

    Returns
    -------
    y : ndarray, shape (n_signals, final_len)
        The resampled signals.
    """
    x_len = x_flat.shape[-1]
    pad = "reflect_limited" if pad == "auto" else pad
    if (isinstance(window, str) and window == "auto") or window is None:
        window = "boxcar"
    if isinstance(npad, str):
        _check_option("npad", npad, ("auto",), extra="when a string")
        # Figure out reasonable pad that gets us to a power of 2
        min_add = min(x_len // 8, 100) * 2
        npad = 2 ** int(np.ceil(np.log2(x_len + min_add))) - x_len
        # Split the padding between both ends; any odd sample goes at the end
        npad, extra = divmod(npad, 2)
        npads = np.array([npad, npad + extra], int)
    else:
        npad = _ensure_int(npad, "npad", extra="or 'auto'")
        npads = np.array([npad, npad], int)
    del npad

    # prep for resampling now
    orig_len = x_len + npads.sum()  # length after padding
    new_len = max(int(round(ratio * orig_len)), 1)  # length after resampling
    # Number of resampled samples to trim from the start and the end so that
    # exactly ``final_len`` samples remain
    to_removes = [int(round(ratio * npads[0]))]
    to_removes.append(new_len - final_len - to_removes[0])
    to_removes = np.array(to_removes)
    # This should hold:
    # assert np.abs(to_removes[1] - to_removes[0]) <= int(np.ceil(ratio))

    # figure out windowing function
    if callable(window):
        W = window(fft.fftfreq(orig_len))
    elif isinstance(window, np.ndarray) and window.shape == (orig_len,):
        W = window
    else:
        W = fft.ifftshift(signal.get_window(window, orig_len))
    # Scale out-of-place so a caller-supplied window array is not modified
    # (and so integer-valued windows do not fail on in-place float scaling)
    W = W * (float(new_len) / float(orig_len))

    # figure out if we should use CUDA
    n_jobs, cuda_dict = _setup_cuda_fft_resample(n_jobs, W, new_len)

    # do the resampling using an adaptation of scipy's FFT-based resample()
    # use of the 'flat' window is recommended for minimal ringing
    parallel, p_fun, n_jobs = parallel_func(_fft_resample, n_jobs)
    if n_jobs == 1:
        y = np.zeros((len(x_flat), new_len - to_removes.sum()), dtype=x_flat.dtype)
        for xi, x_ in enumerate(x_flat):
            y[xi] = _fft_resample(x_, new_len, npads, to_removes, cuda_dict, pad)
    else:
        y = parallel(
            p_fun(x_, new_len, npads, to_removes, cuda_dict, pad) for x_ in x_flat
        )
        y = np.array(y)

    return y


def _resample_stim_channels(stim_data, up, down):
    """Resample stim channels, carefully.

    Each output sample takes the first non-zero value found in the matching
    window of input samples, or the first value of that window when it holds
    no non-zero entries.

    Parameters
    ----------
    stim_data : array, shape (n_samples,) or (n_stim_channels, n_samples)
        Stim channels to resample.
    up : float
        Factor to upsample by.
    down : float
        Factor to downsample by.

    Returns
    -------
    stim_resampled : array, shape (n_stim_channels, n_samples_resampled)
        The resampled stim channels.

    Notes
    -----
    The approach taken here is equivalent to the approach in the C-code.
    See the decimate_stimch function in MNE/mne_browse_raw/save.c
    """
    stim_data = np.atleast_2d(stim_data)
    n_stim_channels, n_samples = stim_data.shape

    ratio = float(up) / down
    n_out = int(round(n_samples * ratio))
    stim_resampled = np.zeros((n_stim_channels, n_out))

    # Input index where each output sample's window begins; clip to protect
    # against out-of-bounds, which can happen (having one sample more than
    # expected) due to padding
    starts = np.minimum((np.arange(n_out) / ratio).astype(int), n_samples - 1)
    # Each window ends where the next one begins; the last ends at n_samples
    stops = np.r_[starts[1:], n_samples]

    for col, (start, stop) in enumerate(zip(starts, stops)):
        for row, stim in enumerate(stim_data):
            hits = stim[start:stop].nonzero()[0]
            # Use the first non-zero value, else the first value of the window
            offset = hits[0] if len(hits) > 0 else 0
            stim_resampled[row, col] = stim[start + offset]

    return stim_resampled


def detrend(x, order=1, axis=-1):
    """Detrend the array x.

    Parameters
    ----------
    x : n-d array
        Signal to detrend.
    order : int
        Fit order. Currently must be '0' or '1'.
    axis : int
        Axis of the array to operate on. Negative values count from the
        last axis.

    Returns
    -------
    y : array
        The x array detrended.

    Raises
    ------
    ValueError
        If ``axis`` is out of bounds for ``x``, or if ``order`` is not
        0 or 1.

    Examples
    --------
    As in :func:`scipy.signal.detrend`::

        >>> randgen = np.random.RandomState(9)
        >>> npoints = int(1e3)
        >>> noise = randgen.randn(npoints)
        >>> x = 3 + 2*np.linspace(0, 1, npoints) + noise
        >>> bool((detrend(x) - noise).max() < 0.01)
        True
    """
    # np.ndim also works for array-likes (e.g., lists), which
    # scipy.signal.detrend accepts as well
    ndim = np.ndim(x)
    # Valid axes are -ndim ... ndim - 1; reject everything else up front so
    # that the caller gets a ValueError instead of an error from inside SciPy
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for x, which has {ndim} axes"
        )
    if order == 0:
        fit = "constant"
    elif order == 1:
        fit = "linear"
    else:
        raise ValueError("order must be 0 or 1")

    y = signal.detrend(x, axis=axis, type=fit)

    return y


# Taken from Ifeachor and Jervis p. 356.
# Note that here the passband ripple and stopband attenuation are
# redundant. The scalar passband ripple δp is expressed in dB as
# 20 * log10(1+δp), but the scalar stopband ripple δs is expressed in dB as
# -20 * log10(δs). So if we know that our stopband attenuation is 53 dB
# (Hamming) then δs = 10 ** (53 / -20.), which means that the passband
# deviation should be 20 * np.log10(1 + 10 ** (53 / -20.)) == 0.0194.
# Per-window reporting values, keyed by the lowercase window name accepted as
# ``fir_window``: display name, passband ripple and stopband attenuation (dB).
_fir_window_dict = {
    "hann": dict(name="Hann", ripple=0.0546, attenuation=44),
    "hamming": dict(name="Hamming", ripple=0.0194, attenuation=53),
    "blackman": dict(name="Blackman", ripple=0.0017, attenuation=74),
}
# Accepted ``fir_window`` values (sorted for stable error messages)
_known_fir_windows = tuple(sorted(_fir_window_dict.keys()))
# Accepted ``phase`` values when FIR filtering
_known_phases_fir = ("linear", "zero", "zero-double", "minimum", "minimum-half")
# Accepted ``phase`` values when IIR filtering
_known_phases_iir = ("zero", "zero-double", "forward")
# Accepted ``fir_design`` values
_known_fir_designs = ("firwin", "firwin2")
# Human-readable description of each FIR design method, used when logging
_fir_design_dict = {
    "firwin": "Windowed time-domain",
    "firwin2": "Windowed frequency-domain",
}


def _to_samples(filter_length, sfreq, phase, fir_design):
    """Convert a filter length specification to a number of samples.

    Parameters
    ----------
    filter_length : str | int
        Either an integer number of samples, or a human-readable duration
        ending in ``"s"`` or ``"ms"`` (case-insensitive), e.g. ``"10s"``.
        ``"auto"`` is not resolved here and raises like any other string
        lacking a time unit.
    sfreq : float
        Sampling frequency, used to convert a duration to samples.
    phase : str
        Not used by this function.
    fir_design : str
        If ``"firwin"``, a length derived from a duration string is rounded
        up to the next odd number.

    Returns
    -------
    filter_length : int
        The filter length in samples. Lengths derived from a duration string
        are at least 1.

    Raises
    ------
    ValueError
        If a string ``filter_length`` is not a finite number followed by
        ``"s"`` or ``"ms"``.
    """
    _validate_type(filter_length, (str, "int-like"), "filter_length")
    if isinstance(filter_length, str):
        filter_length = filter_length.lower()
        err_msg = (
            "filter_length, if a string, must be a "
            'human-readable time, e.g. "10s", or "auto", not '
            f'"{filter_length}"'
        )
        # "ms" must be tested before "s". Using endswith (rather than
        # indexing the last character) keeps the empty string on the
        # ValueError path instead of raising IndexError.
        if filter_length.endswith("ms"):
            mult_fact = 1e-3
            filter_length = filter_length[:-2]
        elif filter_length.endswith("s"):
            mult_fact = 1
            filter_length = filter_length[:-1]
        else:
            raise ValueError(err_msg)
        # now get the number
        try:
            filter_length = float(filter_length)
        except ValueError:
            raise ValueError(err_msg) from None
        # float() accepts "nan" and "inf", which cannot be turned into a
        # sample count
        if not np.isfinite(filter_length):
            raise ValueError(err_msg)
        filter_length = max(int(np.ceil(filter_length * mult_fact * sfreq)), 1)
        if fir_design == "firwin":
            # adds 1 to even lengths, leaves odd lengths untouched
            filter_length += (filter_length - 1) % 2
    filter_length = _ensure_int(filter_length, "filter_length")
    return filter_length


def _triage_filter_params(
    x,
    sfreq,
    l_freq,
    h_freq,
    l_trans_bandwidth,
    h_trans_bandwidth,
    filter_length,
    method,
    phase,
    fir_window,
    fir_design,
    bands="scalar",
    reverse=False,
):
    """Validate and automate filter parameter selection.

    Parameters
    ----------
    x : array | None
        Data to be filtered. If not None, it is passed through
        ``_check_filterable`` and the length of its last axis is compared
        with the filter length.
    sfreq : float
        Sampling frequency (converted to ``float``).
    l_freq : float | array-like | None
        Lower pass-band edge(s). Must be greater than zero if given.
    h_freq : float | array-like | None
        Upper pass-band edge(s). Must be less than Nyquist if given.
    l_trans_bandwidth : float | array-like | str
        Lower transition bandwidth or ``"auto"``, in which case
        ``min(max(0.25 * l_freq, 2.), l_freq)`` is used. Ignored for IIR.
    h_trans_bandwidth : float | array-like | str
        Upper transition bandwidth or ``"auto"``, in which case
        ``min(max(0.25 * h_freq, 2.), sfreq / 2. - h_freq)`` is used.
        Ignored for IIR.
    filter_length : str | int
        Filter length in samples, a duration string (see ``_to_samples``), or
        ``"auto"`` to derive it from the window type and the smallest
        transition bandwidth. Only processed for FIR.
    method : str
        ``"iir"`` skips the FIR-specific processing; ``"fir"`` selects the
        FIR phase options. Any value other than ``"iir"`` goes through the
        FIR branch.
    phase : str
        Filter phase; validated against the options known for ``method``.
    fir_window : str
        FIR window name, one of ``_known_fir_windows``.
    fir_design : str
        FIR design method, one of ``_known_fir_designs``.
    bands : str
        If ``"arr"``, frequencies and transition bandwidths are cast to 1D
        float arrays, otherwise to ``float``.
    reverse : bool
        If True, use band-stop conventions (the transition bands lie inside
        the ``l_freq``/``h_freq`` edges rather than outside).

    Returns
    -------
    params : tuple
        ``(x, sfreq, l_freq, h_freq, l_stop, h_stop, filter_length, phase,
        fir_window, fir_design)`` after validation and casting.

    Raises
    ------
    ValueError
        If a frequency, transition bandwidth, or filter length is invalid.
    """
    _validate_type(phase, "str", "phase")
    if method == "fir":
        _check_option("phase", phase, _known_phases_fir, extra="when FIR filtering")
    else:
        _check_option("phase", phase, _known_phases_iir, extra="when IIR filtering")
    _validate_type(fir_window, "str", "fir_window")
    _check_option("fir_window", fir_window, _known_fir_windows)
    _validate_type(fir_design, "str", "fir_design")
    _check_option("fir_design", fir_design, _known_fir_designs)

    # Helpers for reporting
    report_phase = "non-linear phase" if phase == "minimum" else "zero-phase"
    causality = "causal" if phase == "minimum" else "non-causal"
    if phase == "zero-double":
        report_pass = "two-pass forward and reverse"
    else:
        report_pass = "one-pass"
    if l_freq is not None:
        if h_freq is not None:
            kind = "bandstop" if reverse else "bandpass"
        else:
            kind = "highpass"
            assert not reverse
    elif h_freq is not None:
        kind = "lowpass"
        assert not reverse
    else:
        kind = "allpass"

    def float_array(c):
        return np.array(c, float).ravel()

    def _freq_rep(value):
        # Representation for messages: a single value is rendered with two
        # decimals, anything else is returned as a float array so that the
        # NumPy print options active at formatting time apply. (An ndarray
        # cannot be formatted with a ":0.2f" spec directly.)
        rep = np.array(value, float)
        if rep.size == 1:
            rep = f"{rep.item():0.2f}"
        return rep

    if bands == "arr":
        cast = float_array
    else:
        cast = float
    sfreq = float(sfreq)
    if l_freq is not None:
        l_freq = cast(l_freq)
        if np.any(l_freq <= 0):
            raise ValueError(f"highpass frequency {l_freq} must be greater than zero")
    if h_freq is not None:
        h_freq = cast(h_freq)
        if np.any(h_freq >= sfreq / 2.0):
            raise ValueError(
                f"lowpass frequency {h_freq} must be less than Nyquist ({sfreq / 2.0})"
            )

    dB_cutoff = False  # meaning, don't try to compute or report
    # Cutoffs are only reported for a single band; in array mode both edges
    # must be present (len() of None would raise) and hold one value each
    single_band = bands == "scalar" or (
        h_freq is not None
        and l_freq is not None
        and len(h_freq) == 1
        and len(l_freq) == 1
    )
    if single_band:
        if phase == "zero":
            dB_cutoff = "-6 dB"
        elif phase == "zero-double":
            dB_cutoff = "-12 dB"

    if method == "iir":
        # Ignore these parameters, effectively
        l_stop, h_stop = l_freq, h_freq
    else:  # method == 'fir'
        l_stop = h_stop = None
        logger.info("")
        logger.info("FIR filter parameters")
        logger.info("---------------------")
        logger.info(
            f"Designing a {report_pass}, {report_phase}, {causality} {kind} filter:"
        )
        logger.info(f"- {_fir_design_dict[fir_design]} design ({fir_design}) method")
        this_dict = _fir_window_dict[fir_window]
        if fir_design == "firwin":
            logger.info(
                "- {name:s} window with {ripple:0.4f} passband ripple "
                "and {attenuation:d} dB stopband attenuation".format(**this_dict)
            )
        else:
            logger.info("- {name:s} window".format(**this_dict))

        if l_freq is not None:  # high-pass component
            if isinstance(l_trans_bandwidth, str):
                if l_trans_bandwidth != "auto":
                    raise ValueError(
                        'l_trans_bandwidth must be "auto" if string, got "'
                        f'{l_trans_bandwidth}"'
                    )
                l_trans_bandwidth = np.minimum(np.maximum(0.25 * l_freq, 2.0), l_freq)
            l_trans_rep = _freq_rep(l_trans_bandwidth)
            with np.printoptions(precision=2, floatmode="fixed"):
                msg = f"- Lower transition bandwidth: {l_trans_rep} Hz"
                if dB_cutoff:
                    l_freq_rep = _freq_rep(l_freq)
                    cutoff_rep = _freq_rep(l_freq - l_trans_bandwidth / 2.0)
                    # Could be an array
                    logger.info(f"- Lower passband edge: {l_freq_rep}")
                    msg += f" ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)"
            logger.info(msg)
            l_trans_bandwidth = cast(l_trans_bandwidth)
            if np.any(l_trans_bandwidth <= 0):
                raise ValueError(
                    f"l_trans_bandwidth must be positive, got {l_trans_bandwidth}"
                )
            l_stop = l_freq - l_trans_bandwidth
            if reverse:  # band-stop style
                l_stop += l_trans_bandwidth
                l_freq += l_trans_bandwidth
            if np.any(l_stop < 0):
                # l_stop is an array when bands == "arr", so it must not be
                # formatted with a float format spec directly
                with np.printoptions(precision=2, floatmode="fixed"):
                    raise ValueError(
                        "Filter specification invalid: Lower stop frequency negative ("
                        f"{_freq_rep(l_stop)} Hz). Increase pass frequency or reduce "
                        "the transition bandwidth (l_trans_bandwidth)"
                    )
        if h_freq is not None:  # low-pass component
            if isinstance(h_trans_bandwidth, str):
                if h_trans_bandwidth != "auto":
                    raise ValueError(
                        'h_trans_bandwidth must be "auto" if '
                        f'string, got "{h_trans_bandwidth}"'
                    )
                h_trans_bandwidth = np.minimum(
                    np.maximum(0.25 * h_freq, 2.0), sfreq / 2.0 - h_freq
                )
            h_trans_rep = _freq_rep(h_trans_bandwidth)
            with np.printoptions(precision=2, floatmode="fixed"):
                msg = f"- Upper transition bandwidth: {h_trans_rep} Hz"
                if dB_cutoff:
                    h_freq_rep = _freq_rep(h_freq)
                    cutoff_rep = _freq_rep(h_freq + h_trans_bandwidth / 2.0)
                    logger.info(f"- Upper passband edge: {h_freq_rep} Hz")
                    msg += f" ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)"
            logger.info(msg)
            h_trans_bandwidth = cast(h_trans_bandwidth)
            if np.any(h_trans_bandwidth <= 0):
                raise ValueError(
                    f"h_trans_bandwidth must be positive, got {h_trans_bandwidth}"
                )
            h_stop = h_freq + h_trans_bandwidth
            if reverse:  # band-stop style
                h_stop -= h_trans_bandwidth
                h_freq -= h_trans_bandwidth
            if np.any(h_stop > sfreq / 2):
                raise ValueError(
                    f"Effective band-stop frequency ({h_stop}) is too high (maximum "
                    f"based on Nyquist is {sfreq / 2.0})"
                )

        if isinstance(filter_length, str) and filter_length.lower() == "auto":
            filter_length = filter_length.lower()
            # The narrowest transition band determines the required length
            h_check = l_check = np.inf
            if h_freq is not None:
                h_check = min(np.atleast_1d(h_trans_bandwidth))
            if l_freq is not None:
                l_check = min(np.atleast_1d(l_trans_bandwidth))
            mult_fact = 2.0 if fir_design == "firwin2" else 1.0
            filter_length = f"{_length_factors[fir_window] * mult_fact / float(min(h_check, l_check))}s"  # noqa: E501
            next_pow_2 = False  # disable old behavior
        else:
            # only duration strings in zero-double mode are rounded up to the
            # next power of two
            next_pow_2 = isinstance(filter_length, str) and phase == "zero-double"

        filter_length = _to_samples(filter_length, sfreq, phase, fir_design)

        # use correct type of filter (must be odd length for firwin and for
        # zero phase)
        if fir_design == "firwin" or phase == "zero":
            filter_length += (filter_length - 1) % 2

        logger.info(
            f"- Filter length: {filter_length} samples ({filter_length / sfreq:0.3f} s)"
        )
        logger.info("")

        if filter_length <= 0:
            raise ValueError(f"filter_length must be positive, got {filter_length}")

        if next_pow_2:
            filter_length = 2 ** int(np.ceil(np.log2(filter_length)))
            if fir_design == "firwin":
                filter_length += (filter_length - 1) % 2

    # If we have data supplied, do a sanity check
    if x is not None:
        x = _check_filterable(x)
        len_x = x.shape[-1]
        if method != "fir":
            filter_length = len_x
        if filter_length > len_x and not (l_freq is None and h_freq is None):
            warn(
                f"filter_length ({filter_length}) is longer than the signal ({len_x}), "
                "distortion is likely. Reduce filter length or filter a longer signal."
            )

    logger.debug(f"Using filter length: {filter_length}")
    return (
        x,
        sfreq,
        l_freq,
        h_freq,
        l_stop,
        h_stop,
        filter_length,
        phase,
        fir_window,
        fir_design,
    )


def _check_resamp_noop(sfreq, o_sfreq, rtol=1e-6):
    if np.isclose(sfreq, o_sfreq, atol=0, rtol=rtol):
        logger.info(
            f"Sampling frequency of the instance is already {sfreq}, returning "
            "unmodified."
        )
        return True
    return False


class FilterMixin:
    """Object for Epoch/Evoked filtering.

    Mixin providing Savitzky-Golay filtering, FIR/IIR filtering, resampling
    and Hilbert-transform methods. The methods operate in place on
    ``self._data`` and read the sampling rate from ``self.info["sfreq"]``
    (or from ``self.tstep`` for source estimates).
    """

    @verbose
    def savgol_filter(self, h_freq, verbose=None):
        """Filter the data using Savitzky-Golay polynomial method.

        Parameters
        ----------
        h_freq : float
            Approximate high cut-off frequency in Hz. Must be positive and
            less than half the sample rate. Note that this
            is not an exact cutoff, since Savitzky-Golay filtering
            :footcite:`SavitzkyGolay1964` is done using polynomial fits
            instead of FIR/IIR filtering. This parameter is thus used to
            determine the length of the window over which a 5th-order
            polynomial smoothing is used.
        %(verbose)s

        Returns
        -------
        inst : instance of Epochs, Evoked or SourceEstimate
            The object with the filtering applied.

        See Also
        --------
        mne.io.Raw.filter

        Notes
        -----
        For Savitzky-Golay low-pass approximation, see:

            https://gist.github.com/larsoner/bbac101d50176611136b

        When working on SourceEstimates the sample rate of the original
        data is inferred from tstep.

        .. versionadded:: 0.9.0

        References
        ----------
        .. footbibliography::

        Examples
        --------
        >>> import mne
        >>> from os import path as op
        >>> evoked_fname = op.join(mne.datasets.sample.data_path(), 'MEG', 'sample', 'sample_audvis-ave.fif')  # doctest:+SKIP
        >>> evoked = mne.read_evokeds(evoked_fname, baseline=(None, 0))[0]  # doctest:+SKIP
        >>> evoked.savgol_filter(10.)  # low-pass at around 10 Hz # doctest:+SKIP
        >>> evoked.plot()  # doctest:+SKIP
        """  # noqa: E501
        from .source_estimate import _BaseSourceEstimate

        _check_preload(self, "inst.savgol_filter")
        if not isinstance(self, _BaseSourceEstimate):
            s_freq = self.info["sfreq"]
        else:
            s_freq = 1 / self.tstep
        h_freq = float(h_freq)
        # The window length below is derived from s_freq / h_freq, so zero
        # would divide by zero and negative values give a negative length
        if not h_freq > 0:
            raise ValueError(f"h_freq must be positive, got {h_freq}")
        if h_freq >= s_freq / 2.0:
            raise ValueError("h_freq must be less than half the sample rate")

        # savitzky-golay filtering; the window length is forced to be odd
        window_length = (int(np.round(s_freq / h_freq)) // 2) * 2 + 1
        logger.info("Using savgol length %d", window_length)
        self._data[:] = signal.savgol_filter(
            self._data, axis=-1, polyorder=5, window_length=window_length
        )
        return self

    @verbose
    def filter(
        self,
        l_freq,
        h_freq,
        picks=None,
        filter_length="auto",
        l_trans_bandwidth="auto",
        h_trans_bandwidth="auto",
        n_jobs=None,
        method="fir",
        iir_params=None,
        phase="zero",
        fir_window="hamming",
        fir_design="firwin",
        skip_by_annotation=("edge", "bad_acq_skip"),
        pad="edge",
        *,
        verbose=None,
    ):
        """Filter a subset of channels/vertices.

        Parameters
        ----------
        %(l_freq)s
        %(h_freq)s
        %(picks_all_data)s
        %(filter_length)s
        %(l_trans_bandwidth)s
        %(h_trans_bandwidth)s
        %(n_jobs_fir)s
        %(method_fir)s
        %(iir_params)s
        %(phase)s
        %(fir_window)s
        %(fir_design)s
        %(skip_by_annotation)s

            .. versionadded:: 0.16.
        %(pad_fir)s
        %(verbose)s

        Returns
        -------
        inst : instance of Epochs, Evoked, SourceEstimate, or Raw
            The filtered data.

        See Also
        --------
        mne.filter.create_filter
        mne.Evoked.savgol_filter
        mne.io.Raw.notch_filter
        mne.io.Raw.resample
        mne.filter.filter_data
        mne.filter.construct_iir_filter

        Notes
        -----
        Applies a zero-phase low-pass, high-pass, band-pass, or band-stop
        filter to the channels selected by ``picks``.
        The data are modified inplace.

        The object has to have the data loaded e.g. with ``preload=True``
        or ``self.load_data()``.

        ``l_freq`` and ``h_freq`` are the frequencies below which and above
        which, respectively, to filter out of the data. Thus the uses are:

            * ``l_freq < h_freq``: band-pass filter
            * ``l_freq > h_freq``: band-stop filter
            * ``l_freq is not None and h_freq is None``: high-pass filter
            * ``l_freq is None and h_freq is not None``: low-pass filter

        ``self.info['lowpass']`` and ``self.info['highpass']`` are only
        updated with picks=None.

        .. note:: If n_jobs > 1, more memory is required as
                  ``len(picks) * n_times`` additional time points need to
                  be temporarily stored in memory.

        When working on SourceEstimates the sample rate of the original
        data is inferred from tstep.

        For more information, see the tutorials
        :ref:`disc-filtering` and :ref:`tut-filter-resample` and
        :func:`mne.filter.create_filter`.

        .. versionadded:: 0.15
        """
        from .annotations import _annotations_starts_stops
        from .io import BaseRaw
        from .source_estimate import _BaseSourceEstimate

        _check_preload(self, "inst.filter")
        if not isinstance(self, _BaseSourceEstimate):
            update_info, picks = _filt_check_picks(self.info, picks, l_freq, h_freq)
            s_freq = self.info["sfreq"]
        else:
            # source estimates carry no measurement info
            s_freq = 1.0 / self.tstep
        if pad is None and method != "iir":
            pad = "edge"
        if isinstance(self, BaseRaw):
            # Deal with annotations
            onsets, ends = _annotations_starts_stops(
                self, skip_by_annotation, invert=True
            )
            logger.info(
                "Filtering raw data in %d contiguous segment%s",
                len(onsets),
                _pl(onsets),
            )
        else:
            # A single segment spanning axis 1 of the data
            onsets, ends = np.array([0]), np.array([self._data.shape[1]])
        max_idx = (ends - onsets).argmax()
        for si, (start, stop) in enumerate(zip(onsets, ends)):
            # Only output filter params once (for info level), and only warn
            # once about the length criterion (longest segment is too short)
            use_verbose = verbose if si == max_idx else "error"
            filter_data(
                self._data[:, start:stop],
                s_freq,
                l_freq,
                h_freq,
                picks,
                filter_length,
                l_trans_bandwidth,
                h_trans_bandwidth,
                n_jobs,
                method,
                iir_params,
                copy=False,
                phase=phase,
                fir_window=fir_window,
                fir_design=fir_design,
                pad=pad,
                verbose=use_verbose,
            )
        # update info if filter is applied to all data channels/vertices,
        # and it's not a band-stop filter
        if not isinstance(self, _BaseSourceEstimate):
            _filt_update_info(self.info, update_info, l_freq, h_freq)
        return self

    @verbose
    def resample(
        self,
        sfreq,
        *,
        npad="auto",
        window="auto",
        n_jobs=None,
        pad="edge",
        method="fft",
        verbose=None,
    ):
        """Resample data.

        If appropriate, an anti-aliasing filter is applied before resampling.
        See :ref:`resampling-and-decimating` for more information.

        .. note:: Data must be loaded.

        Parameters
        ----------
        sfreq : float
            New sample rate to use.
        %(npad)s
        %(window_resample)s
        %(n_jobs_cuda)s
        %(pad_resample)s

            .. versionadded:: 0.15
        %(method_resample)s

            .. versionadded:: 1.7
        %(verbose)s

        Returns
        -------
        inst : instance of Epochs or Evoked
            The resampled object.

        See Also
        --------
        mne.io.Raw.resample

        Notes
        -----
        For some data, it may be more accurate to use npad=0 to reduce
        artifacts. This is dataset dependent -- check your data!
        """
        from .epochs import BaseEpochs
        from .evoked import Evoked

        # Should be guaranteed by our inheritance, and the fact that
        # mne.io.BaseRaw and _BaseSourceEstimate overrides this method
        assert isinstance(self, BaseEpochs | Evoked)

        sfreq = float(sfreq)
        o_sfreq = self.info["sfreq"]
        # Nothing to do if the requested rate matches the current one
        if _check_resamp_noop(sfreq, o_sfreq):
            return self

        _check_preload(self, "inst.resample")
        self._data = resample(
            self._data,
            sfreq,
            o_sfreq,
            npad=npad,
            window=window,
            n_jobs=n_jobs,
            pad=pad,
            method=method,
        )
        # The stored low-pass can be no higher than the new Nyquist frequency
        lowpass = self.info.get("lowpass")
        lowpass = np.inf if lowpass is None else lowpass
        with self.info._unlock():
            self.info["lowpass"] = min(lowpass, sfreq / 2.0)
            self.info["sfreq"] = float(sfreq)
        # New time axis: same first time point, new sample spacing
        new_times = (
            np.arange(self._data.shape[-1], dtype=np.float64) / sfreq + self.times[0]
        )
        # adjust indirectly affected variables
        self._set_times(new_times)
        self._raw_times = self.times
        self._update_first_last()
        return self

    @verbose
    def apply_hilbert(
        self, picks=None, envelope=False, n_jobs=None, n_fft="auto", *, verbose=None
    ):
        """Compute analytic signal or envelope for a subset of channels/vertices.

        Parameters
        ----------
        %(picks_all_data_noref)s
        envelope : bool
            Compute the envelope signal of each channel/vertex. Default False.
            See Notes.
        %(n_jobs)s
        n_fft : int | None | str
            Points to use in the FFT for Hilbert transformation. The signal
            will be padded with zeros before computing Hilbert, then cut back
            to original length. If None, n == self.n_times. If 'auto',
            the next highest fast FFT length will be used.
        %(verbose)s

        Returns
        -------
        self : instance of Raw, Epochs, Evoked or SourceEstimate
            The raw object with transformed data.

        Notes
        -----
        **Parameters**

        If ``envelope=False``, the analytic signal for the channels/vertices defined in
        ``picks`` is computed and the data of the Raw object is converted to
        a complex representation (the analytic signal is complex valued).

        If ``envelope=True``, the absolute value of the analytic signal for the
        channels/vertices defined in ``picks`` is computed, resulting in the envelope
        signal.

        .. warning:: Do not use ``envelope=True`` if you intend to compute
                     an inverse solution from the raw data. If you want to
                     compute the envelope in source space, use
                     ``envelope=False`` and compute the envelope after the
                     inverse solution has been obtained.

        If envelope=False, more memory is required since the original raw data
        as well as the analytic signal have temporarily to be stored in memory.
        If n_jobs > 1, more memory is required as ``len(picks) * n_times``
        additional time points need to be temporarily stored in memory.

        Also note that the ``n_fft`` parameter will allow you to pad the signal
        with zeros before performing the Hilbert transform. This padding
        is cut off, but it may result in a slightly different result
        (particularly around the edges). Use at your own risk.

        **Analytic signal**

        The analytic signal "x_a(t)" of "x(t)" is::

            x_a = F^{-1}(F(x) 2U) = x + i y

        where "F" is the Fourier transform, "U" the unit step function,
        and "y" the Hilbert transform of "x". One usage of the analytic
        signal is the computation of the envelope signal, which is given by
        "e(t) = abs(x_a(t))". Due to the linearity of Hilbert transform and the
        MNE inverse solution, the envelope in source space can be obtained
        by computing the analytic signal in sensor space, applying the MNE
        inverse, and computing the envelope in source space.
        """
        from .source_estimate import _BaseSourceEstimate

        if not isinstance(self, _BaseSourceEstimate):
            use_info = self.info
        else:
            # no measurement info available: pick by index into the data
            use_info = len(self._data)
        _check_preload(self, "inst.apply_hilbert")
        picks = _picks_to_idx(use_info, picks, exclude=(), with_ref_meg=False)

        if n_fft is None:
            n_fft = len(self.times)
        elif isinstance(n_fft, str):
            if n_fft != "auto":
                raise ValueError(
                    f"n_fft must be an integer, string, or None, got {type(n_fft)}"
                )
            n_fft = next_fast_len(len(self.times))
        n_fft = int(n_fft)
        if n_fft < len(self.times):
            raise ValueError(
                f"n_fft ({n_fft}) must be at least the number of time points ("
                f"{len(self.times)})"
            )
        # The analytic signal is complex; the envelope keeps the input dtype
        dtype = None if envelope else np.complex128
        args, kwargs = (), dict(n_fft=n_fft, envelope=envelope)

        # Keep a reference to the original data as the input, and convert the
        # stored data to the output dtype if needed
        data_in = self._data
        if dtype is not None and dtype != self._data.dtype:
            self._data = self._data.astype(dtype)

        parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs)
        if n_jobs == 1:
            # modify data inplace to save memory
            for idx in picks:
                self._data[..., idx, :] = _check_fun(
                    _my_hilbert, data_in[..., idx, :], *args, **kwargs
                )
        else:
            # use parallel function
            data_picks_new = parallel(
                p_fun(_my_hilbert, data_in[..., p, :], *args, **kwargs) for p in picks
            )
            for pp, p in enumerate(picks):
                self._data[..., p, :] = data_picks_new[pp]
        return self


def _check_fun(fun, d, *args, **kwargs):
    """Check shapes."""
    want_shape = d.shape
    d = fun(d, *args, **kwargs)
    if not isinstance(d, np.ndarray):
        raise TypeError("Return value must be an ndarray")
    if d.shape != want_shape:
        raise ValueError(f"Return data must have shape {want_shape} not {d.shape}")
    return d


def _my_hilbert(x, n_fft=None, envelope=False):
    """Compute Hilbert transform of signals w/ zero padding.

    Parameters
    ----------
    x : array, shape (n_times)
        The signal to convert
    n_fft : int
        Size of the FFT to perform, must be at least ``len(x)``.
        The signal will be cut back to original length.
    envelope : bool
        Whether to compute amplitude of the hilbert transform in order
        to return the signal envelope.

    Returns
    -------
    out : array, shape (n_times)
        The hilbert transform of the signal, or the envelope.
    """
    n_x = x.shape[-1]
    out = signal.hilbert(x, N=n_fft, axis=-1)[..., :n_x]
    if envelope:
        out = np.abs(out)
    return out


@verbose
def design_mne_c_filter(
    sfreq,
    l_freq=None,
    h_freq=40.0,
    l_trans_bandwidth=None,
    h_trans_bandwidth=5.0,
    verbose=None,
):
    """Create a FIR filter like that used by MNE-C.

    Parameters
    ----------
    sfreq : float
        The sample frequency.
    l_freq : float | None
        The low filter frequency in Hz, default None.
        Can be None to avoid high-passing.
    h_freq : float | None
        The high filter frequency in Hz, default 40.
        Can be None to avoid low-passing.
    l_trans_bandwidth : float | None
        Low transition bandwidth in Hz. Can be None (default) to use 3
        samples.
    h_trans_bandwidth : float
        High transition bandwidth in Hz.
    %(verbose)s

    Returns
    -------
    h : ndarray, shape (8193,)
        The linear-phase (symmetric) FIR filter coefficients.

    Notes
    -----
    This function is provided mostly for reference purposes.

    MNE-C uses a frequency-domain filter design technique by creating a
    linear-phase filter of length 8193. In the frequency domain, the
    4097 frequencies are directly constructed, with zeroes in the stop-band
    and ones in the passband, with squared cosine ramps in between.
    """
    n_freqs = (4096 + 2 * 2048) // 2 + 1
    n_bins = n_freqs - 1
    response = np.ones(n_freqs)

    def _squared_cosine(width, offset):
        # Squared-cosine samples over the 2 * width - 1 transition bins;
        # ``offset`` selects the rising (3.0) or falling (1.0) flank
        k = np.arange(-width + 1, width) / float(width) + offset
        return np.cos(np.pi / 4.0 * k) ** 2

    # Express the band edges and transition half-widths in frequency bins
    l_freq = 0 if l_freq is None else float(l_freq)
    nyquist = 0.5 * sfreq
    if l_trans_bandwidth is None:
        l_width = 3
    else:
        l_width = (int((n_bins * l_trans_bandwidth) / nyquist) + 1) // 2
    l_start = int((n_bins * l_freq) / nyquist)
    h_freq = sfreq / 2.0 if h_freq is None else float(h_freq)
    h_width = (int((n_bins * h_trans_bandwidth) / nyquist) + 1) // 2
    h_start = int((n_bins * h_freq) / nyquist)
    logger.info(
        "filter : %7.3f ... %6.1f Hz   bins : %d ... %d of %d hpw : %d lpw : %d",
        l_freq,
        h_freq,
        l_start,
        h_start,
        n_freqs,
        l_width,
        h_width,
    )

    # High-pass part: zeros below the transition, then a rising ramp
    if l_freq > 0:
        ramp_start = l_start - l_width + 1
        ramp_stop = ramp_start + 2 * l_width - 1
        if ramp_start < 0 or ramp_stop >= n_freqs:
            raise RuntimeError("l_freq too low or l_trans_bandwidth too large")
        response[:ramp_start] = 0.0
        response[ramp_start:ramp_stop] = _squared_cosine(l_width, 3.0)

    # Low-pass part: a falling ramp, then zeros above the transition
    if h_freq < sfreq / 2.0:
        ramp_start = h_start - h_width + 1
        ramp_stop = ramp_start + 2 * h_width - 1
        if ramp_start < 0 or ramp_stop >= n_freqs:
            raise RuntimeError("h_freq too high or h_trans_bandwidth too large")
        response[ramp_start:ramp_stop] *= _squared_cosine(h_width, 1.0)
        response[ramp_stop:] = 0.0

    # Get the time-domain version of this signal
    h = fft.irfft(response, n=2 * n_freqs - 1)
    # center the impulse like a linear-phase filt
    return np.roll(h, n_bins)


def _filt_check_picks(info, picks, h_freq, l_freq):
    """Resolve ``picks`` and decide whether the filter info may be updated.

    Returns ``(update_info, picks)``. ``update_info`` is True only if at
    least one of the two frequencies is given and every data channel of
    ``info`` is contained in the resolved ``picks``. The two frequency
    arguments are treated symmetrically.
    """
    # This will pick *all* data channels
    picks = _picks_to_idx(info, picks, "data_or_ica", exclude=())
    if h_freq is None and l_freq is None:
        return False, picks

    all_data_picks = _picks_to_idx(
        info, None, "data_or_ica", exclude=(), allow_empty=True
    )
    if len(all_data_picks) == 0:
        logger.info(
            "No data channels found. The highpass and "
            "lowpass values in the measurement info will not "
            "be updated."
        )
        return False, picks
    if np.isin(all_data_picks, picks).all():
        return True, picks
    logger.info(
        "Filtering a subset of channels. The highpass and "
        "lowpass values in the measurement info will not "
        "be updated."
    )
    return False, picks


def _filt_update_info(info, update_info, l_freq, h_freq):
    if update_info:
        if (
            h_freq is not None
            and (l_freq is None or l_freq < h_freq)
            and (info["lowpass"] is None or h_freq < info["lowpass"])
        ):
            with info._unlock():
                info["lowpass"] = float(h_freq)
        if (
            l_freq is not None
            and (h_freq is None or l_freq < h_freq)
            and (info["highpass"] is None or l_freq > info["highpass"])
        ):
            with info._unlock():
                info["highpass"] = float(l_freq)


def _iir_pad_apply_unpad(x, *, func, padlen, padtype, **kwargs):
    x_out = np.reshape(x, (-1, x.shape[-1])).copy()
    for this_x in x_out:
        x_ext = this_x
        if padlen:
            x_ext = _smart_pad(x_ext, (padlen, padlen), padtype)
        x_ext = func(x=x_ext, axis=-1, padlen=0, **kwargs)
        this_x[:] = x_ext[padlen : len(x_ext) - padlen]
    x_out.shape = x.shape
    return x_out
